Unverified Commit 7c1828f8 authored by Ming-Xu Huang, committed by GitHub

Support Low Rank Adaptation (LoRA). (#745)

parent 1442b47e
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose
from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:

    @staticmethod
    def reference(x, la, lb, pattern, scale):
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
                                                       ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
    @pytest.mark.parametrize('rank', [32, 16])
    @pytest.mark.parametrize('alpha', [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)

        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)
        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize('scope_ref_assert',
                             [('none', LoRAScope(False, False, False), False),
                              ('all', LoRAScope(True, True, True), False),
                              ('qkv_proj', LoRAScope(True, False, False), False),
                              ('output_proj', LoRAScope(False, True, False), False),
                              ('mlp', LoRAScope(False, False, True), False),
                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
                              ('exclude_output_proj', LoRAScope(True, False, True), False),
                              ('exclude_mlp', LoRAScope(True, True, False), False),
                              ('messing_up', LoRAScope(), True)])
    def test_lora_scope_generator(self, scope_ref_assert):
        scope, reference, need_assert = scope_ref_assert
        try:
            lora_scope = _canonicalize_lora_scope(scope)
            assert lora_scope == reference
        except AssertionError as ae:
            assert need_assert, f"{ae.args}"
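
The second `axis_features_pattern` case above exercises LoRA on a fused projection, where features = (3, 1024) stacks the Q/K/V outputs and each stacked output gets its own adapter pair. A minimal standalone sketch (names here are illustrative only, not part of the commit) showing that the fused einsum used in the test matches applying an independent A/B pair per stacked projection:

import jax
import jax.numpy as jnp

hidden, rank, num_proj = 64, 8, 3
kx, ka, kb = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (4, hidden))
la = jax.random.normal(ka, (hidden, num_proj, rank))   # one LoRA "A" per stacked projection
lb = jax.random.normal(kb, (num_proj, rank, hidden))   # one LoRA "B" per stacked projection

fused = jnp.einsum('...h,hkr,krz->...kz', x, la, lb)
per_proj = jnp.stack([x @ la[:, k, :] @ lb[k] for k in range(num_proj)], axis=-2)
assert jnp.allclose(fused, per_proj, atol=1e-4)
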
@@ -784,6 +784,7 @@ class MultiHeadAttnAttr:
NUM_GQA_GROUPS = 'num_gqa_groups'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -853,6 +854,22 @@ class MultiHeadAttnAttr:
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding',
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal',
LORA_SCOPE: 'all'
}]
@@ -883,6 +900,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
@@ -905,6 +923,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -926,6 +945,7 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
@@ -969,6 +989,7 @@ class TransformerLayerAttr:
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
LORA_SCOPE = 'low_rank_adaptation_scope'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1113,6 +1134,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
@@ -1185,6 +1216,16 @@ class TransformerLayerAttr:
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False,
LORA_SCOPE: 'all'
}]
@@ -1219,6 +1260,7 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
@@ -1257,6 +1299,7 @@ class TestTransformer(TestLayer):
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
low_rank_adaptation_scope=low_rank_adaptation_scope,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
@@ -1282,6 +1325,7 @@ class TestTransformer(TestLayer):
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
low_rank_adaptation_scope=low_rank_adaptation_scope,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
...
@@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]):
return mask
def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""
    assert len(axis) <= 5
    hidden_in_names = 'ijklm'[:len(axis)]
    assert len(features) <= 5
    hidden_out_names = 'nopqr'[:len(features)]
    rank_name = 's'

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
                           f"->{output_einsum_express}"

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
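
For the common case of a single contraction axis and a single output feature dimension, the generated pattern is '...i,is,sn->...n', i.e. the usual LoRA update x @ A @ B scaled by alpha / rank. A small self-contained sketch (not part of the diff) of that equivalence:

import jax
import jax.numpy as jnp

hidden_in, hidden_out, rank, alpha = 64, 64, 8, 4
kx, ka, kb = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(kx, (2, 16, hidden_in))
lora_a = jax.random.normal(ka, (hidden_in, rank))
lora_b = jax.random.normal(kb, (rank, hidden_out))

via_einsum = jnp.einsum('...i,is,sn->...n', x, lora_a, lora_b) * (alpha / rank)
via_matmul = (x @ lora_a @ lora_b) * (alpha / rank)
assert jnp.allclose(via_einsum, via_matmul, atol=1e-4)
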
class Softmax(nn.Module): # pylint: disable=too-few-public-methods
r"""
Applies softmax over a mini-batch of inputs.
@@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase):
bias_axes: Tuple[str, ...], default = ()
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each linear layer.
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
@@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase):
use_bias: bool = True
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = False
@@ -439,6 +475,32 @@ class DenseGeneral(TransformerEngineBase):
fp8_meta_pkg=fp8_gemm_pkg,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
if bias is not None:
bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
y += jnp.reshape(bias, bias_shape)
@@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
return_layernorm_output: bool, default = True
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each linear layer.
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
@@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias_init: Initializer = nn.initializers.zeros
bias_axes: Tuple[str, ...] = ()
return_layernorm_output: bool = True
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
@@ -650,6 +723,32 @@ class LayerNormDenseGeneral(TransformerEngineBase):
fp8_meta_pkg=fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
self.low_rank_adaptation_dim)
lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
self.kernel_init,
lora_a_kernel_init_shape,
jnp.float32,
axes=lora_a_kernel_axes)
lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
lora_a_kernel = lora_a_kernel.astype(self.dtype)
lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
nn.initializers.zeros,
lora_b_kernel_shape,
jnp.float32,
axes=lora_b_kernel_axes)
lora_b_kernel = lora_b_kernel.astype(self.dtype)
z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('bias',
@@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase):
Dropout probability for the dropout op after the :attr:`activations`.
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each linear layer.
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`.
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
layernorm_input_axes: Tuple[str, ...], default = None
@@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_dropout_rng_name: str = 'dropout'
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
transpose_batch_sequence: bool = True
@@ -856,11 +966,13 @@ class LayerNormMLP(TransformerEngineBase):
use_fused_ln_geglu_mlp = fuse_layernorm \
and (not self.use_bias) and is_geglu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
use_fused_ln_gelu_mlp = fuse_layernorm \
and self.use_bias and is_gelu(self.activations) \
and (self.intermediate_dropout_rate < 1e-3) \
and not self.enable_low_rank_adaptation
# LayerNorm
if self.enable_layernorm:
@@ -999,6 +1111,37 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_pkg=gemm1_fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
self.low_rank_adaptation_dim)
wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
self.low_rank_adaptation_dim)
wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
kernel_1_init,
num_activations,
-2,
wi_lora_a_kernel_init_each_shape,
jnp.float32,
axes=wi_lora_a_kernel_axes)
wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
self.intermediate_dim)
wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
nn.initializers.zeros,
wi_lora_b_kernel_shape,
jnp.float32,
axes=wi_lora_b_kernel_axes)
wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel,
wi_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wi_bias',
@@ -1042,6 +1185,28 @@ class LayerNormMLP(TransformerEngineBase):
fp8_meta_pkg=gemm2_fp8_meta_package,
contracting_dims=(axis, contract_ind))
if self.enable_low_rank_adaptation:
wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
self.kernel_init,
wo_lora_a_kernel_shape,
jnp.float32,
axes=wo_lora_a_kernel_axes)
wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
nn.initializers.zeros,
wo_lora_b_kernel_shape,
jnp.float32,
axes=wo_lora_b_kernel_axes)
wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel,
wo_lora_b_kernel, self.low_rank_adaptation_alpha)
bias = None
if self.use_bias:
bias = nn_partitioning.param_with_axes('wo_bias',
...
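
Not part of the diff, but a rough usage sketch of the new flags on DenseGeneral (the import path and constructor defaults are assumptions based on the code above): enabling LoRA adds two parameters, lora_a_kernel and lora_b_kernel, and because lora_b_kernel is zero-initialized the layer's output is unchanged right after init.

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral   # assumed public export

layer = DenseGeneral(features=128,
                     enable_low_rank_adaptation=True,
                     low_rank_adaptation_dim=16,
                     low_rank_adaptation_alpha=8.0)
variables = layer.init(jax.random.PRNGKey(0), jnp.ones((4, 256), jnp.float32))
# Expect 'lora_a_kernel' and 'lora_b_kernel' alongside the base kernel under 'params'.
print(jax.tree_util.tree_map(jnp.shape, variables['params']))
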
@@ -637,6 +637,53 @@ def rotary_pos_emb(x: Array,
return consecutive_impl()
class LoRAScope: # pylint: disable=too-few-public-methods
    """LoRA Scope"""

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        return (self.qkv_proj, self.output_proj, self.mlp) == \
               (other.qkv_proj, other.output_proj, other.mlp)


def _canonicalize_lora_scope(scope):
    SCOPE_NONE = 'none'
    SCOPE_ALL = 'all'
    SCOPE_QKV_PROJ = 'qkv_proj'
    SCOPE_OUTPUT_PROJ = 'output_proj'
    SCOPE_MLP = 'mlp'
    SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj'
    SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
    SCOPE_EX_MLP = 'exclude_mlp'

    scope = SCOPE_NONE if scope is None else scope
    scope = scope.lower()

    assert scope in [
        SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP
    ]

    lora_scope = LoRAScope()

    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True
    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True
    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope
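
A quick illustration of the mapping implemented above (imports mirror the new test file at the top of this commit):

from transformer_engine.jax.flax.transformer import LoRAScope, _canonicalize_lora_scope

assert _canonicalize_lora_scope(None) == LoRAScope(False, False, False)           # treated as 'none'
assert _canonicalize_lora_scope('all') == LoRAScope(True, True, True)
assert _canonicalize_lora_scope('exclude_mlp') == LoRAScope(True, True, False)    # qkv + output proj only
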
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
r"""
Multi-head Attention (MHA), including Query,
@@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Indicate the method to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`low_rank_adaptation_scope` is not 'none'.
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
num_heads: int, default = None
@@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
@@ -914,6 +973,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)
lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
if self.fuse_qkv_params:
if is_qkvpack:
qkv_proj, ln_out = LayerNormDenseGeneral(
@@ -932,6 +993,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
dot_input_axes=inputs_logical_axes_no_sp,
name='qkv',
@@ -954,6 +1018,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
@@ -972,6 +1039,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_JOINED_AXES, W_TP_AXES),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
name='kv',
dtype=self.dtype)(inputs_kv)
kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')
@@ -986,6 +1056,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype)
query, ln_out = LayerNormDenseGeneral(
enable_layernorm=self.input_layernorm,
@@ -1002,6 +1075,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_TP_AXES,),
enable_low_rank_adaptation=lora_scope.qkv_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
kernel_init=query_init,
layernorm_input_axes=inputs_logical_axes_maybe_sp,
@@ -1142,6 +1218,9 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
use_bias=self.use_bias,
bias_init=self.bias_init,
bias_axes=(W_NO_SHARD_AXES,),
enable_low_rank_adaptation=lora_scope.output_proj,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
dtype=self.dtype,
name='out')(x)
out = checkpoint_name(out, 'out_proj')
@@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
Indicate the method to couple the coordinates. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
low_rank_adaptation_scope: str, default = 'none'
Indicate the scope to apply low rank adaptation. It should be one of
['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
'exclude_output_proj', 'exclude_mlp']
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
:attr:`low_rank_adaptation_scope` is not 'none'.
low_rank_adaptation_alpha: float, default = None
The alpha for computing the scaling factor of LoRA output.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
@@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
@@ -1579,6 +1671,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
@@ -1646,6 +1741,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
@@ -1674,6 +1772,8 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
mlp_input = with_sharding_constraint_by_logical_axes(
mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))
lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
# MlpBlock
residual = mlp_input
z, ln_out = LayerNormMLP(
@@ -1697,6 +1797,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
bias_init=self.bias_init,
bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
bias_axes_2=(W_NO_SHARD_AXES,),
enable_low_rank_adaptation=lora_scope.mlp,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
...
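
Not part of the diff, but a hedged construction sketch for the new TransformerLayer knobs (hidden_size, mlp_hidden_size, and num_attention_heads are assumed to match the existing constructor): the scope string is canonicalized once and then fans out as the enable_low_rank_adaptation flag of the QKV projections, the output projection, and the MLP block.

from transformer_engine.jax.flax import TransformerLayer   # assumed public export

layer = TransformerLayer(hidden_size=1024,
                         mlp_hidden_size=4096,
                         num_attention_heads=16,
                         low_rank_adaptation_scope='exclude_mlp',   # LoRA on qkv_proj and output_proj only
                         low_rank_adaptation_dim=16,
                         low_rank_adaptation_alpha=8.0)
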
@@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer):
use_bias: bool = True
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
sharding_type: ShardingType = ShardingType.SINGLE
@@ -147,6 +150,9 @@ class Linear(TransformerEngineBaseLayer):
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence)
@@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
@@ -201,6 +210,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
axis=self.axis,
dtype=self.dtype,
@@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ('relu',)
intermediate_dropout_rate: float = 0.1
@@ -263,6 +278,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes_1=self.bias_axes_1,
bias_axes_2=self.bias_axes_2,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
activations=self.activations,
intermediate_dropout_rate=self.intermediate_dropout_rate,
...
@@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
@@ -208,6 +211,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
@@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
low_rank_adaptation_scope: str = 'none'
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
@@ -332,6 +341,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
...
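
Not part of this commit, but the typical downstream recipe: train only the LoRA kernels introduced above and freeze everything else. A sketch with optax; the 'lora' substring matches the lora_a_kernel / lora_b_kernel (and wi_/wo_ prefixed) parameter names created by the diff.

import optax
from flax import traverse_util

def label_lora(params):
    # Label every leaf whose parameter name contains 'lora' as trainable.
    flat = traverse_util.flatten_dict(params)
    labels = {path: ('lora' if 'lora' in path[-1] else 'frozen') for path in flat}
    return traverse_util.unflatten_dict(labels)

optimizer = optax.multi_transform({'lora': optax.adamw(1e-4),
                                   'frozen': optax.set_to_zero()},
                                  label_lora)
# optimizer.init(params) / optimizer.update(grads, opt_state, params) as usual.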