Unverified commit 7c1828f8 authored by Ming-Xu Huang, committed by GitHub

Support Low Rank Adaptation (LoRA). (#745)

parent 1442b47e
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import jax
import jax.numpy as jnp
from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope
class TestLoRA:
@staticmethod
def reference(x, la, lb, pattern, scale):
# Reference implementation: contract x with the LoRA A and B kernels, then scale.
out = jnp.einsum(pattern, x, la, lb)
return out * scale
@pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
@pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
@pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
((-1,), (3, 1024), '...h,hkr,krz->...kz')])
@pytest.mark.parametrize('rank', [32, 16])
@pytest.mark.parametrize('alpha', [None, 4, 8])
def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
axis, features, pattern = axis_features_pattern
axis = _normalize_axes(axis, len(shape))
shape_in_axis = tuple(shape[ax] for ax in axis)
key = jax.random.key(1124)
key, x_key = jax.random.split(key)
x = jax.random.normal(x_key, shape, dtype)
key, la_key = jax.random.split(key)
la_shape = (*shape_in_axis, *features[:-1], rank)
la = jax.random.normal(la_key, la_shape, dtype)
key, lb_key = jax.random.split(key)
lb_shape = (*features[:-1], rank, features[-1])
lb = jax.random.normal(lb_key, lb_shape, dtype)
out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
scale_ref = alpha / rank if alpha is not None else 1.0
out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)
assert_allclose(out_target, out_ref, dtype=dtype)
@pytest.mark.parametrize('scope_ref_assert',
[('none', LoRAScope(False, False, False), False),
('all', LoRAScope(True, True, True), False),
('qkv_proj', LoRAScope(True, False, False), False),
('output_proj', LoRAScope(False, True, False), False),
('mlp', LoRAScope(False, False, True), False),
('exclude_qkv_proj', LoRAScope(False, True, True), False),
('exclude_output_proj', LoRAScope(True, False, True), False),
('exclude_mlp', LoRAScope(True, True, False), False),
('messing_up', LoRAScope(), True)])
def test_lora_scope_generator(self, scope_ref_assert):
scope, reference, need_assert = scope_ref_assert
try:
lora_scope = _canonicalize_lora_scope(scope)
assert lora_scope == reference
except AssertionError as ae:
assert need_assert, f"{ae.args}"
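# A sketch of the math these tests verify (values below are illustrative, not
# part of the suite): for a 2D input contracted over its last axis, the LoRA
# update is x @ A @ B scaled by alpha/rank, exactly the '...h,hr,rk->...k'
# pattern parametrized above.
#
#   x = jax.random.normal(jax.random.key(0), (32, 1024))
#   la = jax.random.normal(jax.random.key(1), (1024, 16))  # A: hidden -> rank
#   lb = jnp.zeros((16, 1024))                             # B: rank -> hidden, zero-init
#   delta = jnp.einsum('...h,hr,rk->...k', x, la, lb) * (8 / 16)  # alpha=8, rank=16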
@@ -784,6 +784,7 @@ class MultiHeadAttnAttr:
NUM_GQA_GROUPS = 'num_gqa_groups'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -853,6 +854,22 @@ class MultiHeadAttnAttr:
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all'
}]
@@ -883,6 +900,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
@@ -905,6 +923,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -926,6 +945,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -969,6 +989,7 @@ class TransformerLayerAttr:
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1113,6 +1134,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1185,6 +1216,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}]
@@ -1219,6 +1260,7 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
@@ -1257,6 +1299,7 @@ class TestTransformer(TestLayer):
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
@@ -1282,6 +1325,7 @@ class TestTransformer(TestLayer):
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
@@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]):
return mask
def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
"""Low Rank Adaptation Implementation"""
assert len(axis) <= 5
hidden_in_names = 'ijklm'[:len(axis)]
assert len(features) <= 5
hidden_out_names = 'nopqr'[:len(features)]
rank_name = 's'
assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
rank = lora_a_kernel.shape[-1]
scaling = alpha / rank if alpha is not None else 1.0
x_einsum_express = f"...{hidden_in_names}"
lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
output_einsum_express = f"...{hidden_out_names}"
final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
f"->{output_einsum_express}"
output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
output = output * scaling
return output
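# Example of the generated expression (a sketch, not executed at import time):
# with one contracted axis and features=(1024,), the dims are named 'i' (input),
# 'n' (output) and 's' (rank), giving '...i,is,sn->...n', i.e. x @ A @ B scaled
# by alpha/rank. With features=(3, 1024), e.g. a fused QKV projection, the
# expression becomes '...i,ins,nso->...no'.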
class Softmax(nn.Module): # pylint: disable=too-few-public-methods
r"""
Applies softmax over a mini-batch of inputs.
@@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase):
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for the linear layer.
low_rank_adaptation_dim: int, default = 32
The rank (bottleneck dimension) for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha used to compute the scaling factor of the LoRA output:
:math:`\frac{\alpha}{\text{rank}} \cdot \text{lora\_output}`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
@@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase):
use_bias: bool = True
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
@@ -439,6 +475,32 @@ class DenseGeneral(TransformerEngineBase):
fp8_meta_pkg=fp8_gemm_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bias_shape)
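# Usage sketch (shapes and hyperparameters below are illustrative): because
# lora_b_kernel is zero-initialized, the adapted layer initially reproduces the
# frozen dense output, and only the LoRA branch moves it during fine-tuning.
#   layer = DenseGeneral(features=1024,
#                        enable_low_rank_adaptation=True,
#                        low_rank_adaptation_dim=16,
#                        low_rank_adaptation_alpha=8)
#   params = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 1024)))
#   y = layer.apply(params, jnp.ones((4, 1024)))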
@@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
return_layernorm_output: bool, default = True
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for the linear layer.
low_rank_adaptation_dim: int, default = 32
The rank (bottleneck dimension) for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha used to compute the scaling factor of the LoRA output:
:math:`\frac{\alpha}{\text{rank}} \cdot \text{lora\_output}`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
@@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
@@ -650,6 +723,32 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_pkg=fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
@@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase):
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for the intermediate hidden activations.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each linear layer in the MLP.
low_rank_adaptation_dim: int, default = 32
The rank (bottleneck dimension) for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha used to compute the scaling factor of the LoRA output:
:math:`\frac{\alpha}{\text{rank}} \cdot \text{lora\_output}`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
@@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_dropout_rng_name: str = 'dropout'
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
@@ -856,11 +966,13 @@ class LayerNormMLP(TransformerEngineBase):
# The fused LN+MLP kernels have no LoRA branch, so fall back to the unfused
# path whenever low rank adaptation is enabled.
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
# LayerNorm
if self.enable_layernorm:
@@ -999,6 +1111,37 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_pkg=gemm1_fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
self.low_rank_adaptation_dim)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
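# Note: kernel_1_init is defined earlier in this method (outside this diff);
# judging from its use here, it takes (num_kernels, stack_axis, shape, dtype)
# after the RNG key and stacks num_activations independently initialized
# kernels along axis -2, so the extra positional arguments below are
# init-function parameters rather than a malformed call.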
wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
kernel_1_init,
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
jnp.float32,
axes=wi_lora_a_kernel_axes)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
self.intermediate_dim)
wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
nn.initializers.zeros,
wi_lora_b_kernel_shape,
jnp.float32,
axes=wi_lora_b_kernel_axes)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
@@ -1042,6 +1185,28 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_pkg=gemm2_fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
self.kernel_init,
wo_lora_a_kernel_shape,
jnp.float32,
axes=wo_lora_a_kernel_axes)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
nn.initializers.zeros,
wo_lora_b_kernel_shape,
jnp.float32,
axes=wo_lora_b_kernel_axes)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
@@ -637,6 +637,53 @@ def rotary_pos_emb(x: Array,
return consecutive_impl()
class LoRAScope: # pylint: disable=too-few-public-methods
"""LoRA Scope"""
def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
self.qkv_proj = qkv_proj
self.output_proj = output_proj
self.mlp = mlp
def __eq__(self, other):
return (self.qkv_proj, self.output_proj, self.mlp) == \
(other.qkv_proj, other.output_proj, other.mlp)
def _canonicalize_lora_scope(scope):
SCOPE_NONE = 'none'
SCOPE_ALL = 'all'
SCOPE_QKV_PROJ = 'qkv_proj'
SCOPE_OUTPUT_PROJ = 'output_proj'
SCOPE_MLP = 'mlp'
SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj'
SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
SCOPE_EX_MLP = 'exclude_mlp'
scope = SCOPE_NONE if scope is None else scope
scope = scope.lower()
assert scope in [
SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ,
SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP
]
lora_scope = LoRAScope()
if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
lora_scope.qkv_proj = True
if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
lora_scope.output_proj = True
if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
lora_scope.mlp = True
return lora_scope
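# Examples (illustrative):
#   _canonicalize_lora_scope('all') == LoRAScope(True, True, True)
#   _canonicalize_lora_scope('exclude_mlp') == LoRAScope(True, True, False)
#   _canonicalize_lora_scope(None) == LoRAScope(False, False, False)
# Any string outside the list above trips the assert.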
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Multi-head Attention (MHA), including Query,
@@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Indicate the method used to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension; 'consecutive' pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope in which to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj'].
low_rank_adaptation_dim: int, default = 32
The rank (bottleneck dimension) for low rank adaptation, only used when
:attr:`low_rank_adaptation_scope` is not 'none'.
low_rank_adaptation_alpha: float, default = None
The alpha used to compute the scaling factor of the LoRA output:
:math:`\frac{\alpha}{\text{rank}} \cdot \text{lora\_output}`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
num_heads: int, default = None
@@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
@@ -914,6 +973,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
if self.fuse_qkv_params:
if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral(
@@ -932,6 +993,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='qkv',
@@ -954,6 +1018,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
@@ -972,6 +1039,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
@@ -986,6 +1056,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
@@ -1002,6 +1075,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
@@ -1142,6 +1218,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_NO_SHARD_AXES,),
enable_low_rank_adaptation=lora_scope.output_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
@@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Indicate the method used to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension; 'consecutive' pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope in which to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp'].
low_rank_adaptation_dim: int, default = 32
The rank (bottleneck dimension) for low rank adaptation, only used when
:attr:`low_rank_adaptation_scope` is not 'none'.
low_rank_adaptation_alpha: float, default = None
The alpha used to compute the scaling factor of the LoRA output:
:math:`\frac{\alpha}{\text{rank}} \cdot \text{lora\_output}`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism for operations other than dot.
@@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
@@ -1579,6 +1671,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
@@ -1646,6 +1741,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
@@ -1674,6 +1772,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_input = with_sharding_constraint_by_logical_axes(
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
# MlpBlock
residual = mlp_input
z, ln_out = LayerNormMLP(
@@ -1697,6 +1797,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
enable_low_rank_adaptation=lora_scope.mlp,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
@@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer):
use_bias: bool = True
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
@@ -147,6 +150,9 @@ class Linear(TransformerEngineBaseLayer):
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence)
@@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
@@ -201,6 +210,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
axis=self.axis,
dtype=self.dtype,
@@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1
@@ -263,6 +278,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes_1=self.bias_axes_1,
bias_axes_2=self.bias_axes_2,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
activations=self.activations,
intermediate_dropout_rate=self.intermediate_dropout_rate,
@@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
@@ -208,6 +211,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
@@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
@@ -332,6 +341,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,