OpenDAS / TransformerEngine · Commits

Commit 7c1828f8 (unverified)
Authored Apr 16, 2024 by Ming-Xu Huang; committed via GitHub on Apr 16, 2024
Parent: 1442b47e

    Support Low Rank Adaptation (LoRA). (#745)

Showing 6 changed files with 412 additions and 2 deletions (+412 −2):
  tests/jax/test_functions.py                    +68   −0
  tests/jax/test_praxis_layers.py                +44   −0
  transformer_engine/jax/flax/module.py          +167  −2
  transformer_engine/jax/flax/transformer.py     +103  −0
  transformer_engine/jax/praxis/module.py        +18   −0
  transformer_engine/jax/praxis/transformer.py   +12   −0
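Editor's note (not part of the commit): the change threads three new options — low_rank_adaptation_scope, low_rank_adaptation_dim, and low_rank_adaptation_alpha — through the flax and praxis layers, so enabling LoRA amounts to setting them on an existing module. A minimal sketch, assuming the flax TransformerLayer is imported straight from transformer_engine.jax.flax.transformer and that every non-LoRA attribute keeps its default value:

    from transformer_engine.jax.flax.transformer import TransformerLayer

    # Enable LoRA on every supported projection (QKV, attention output, and MLP).
    layer = TransformerLayer(
        low_rank_adaptation_scope='all',    # or 'qkv_proj', 'output_proj', 'mlp', 'exclude_*', 'none'
        low_rank_adaptation_dim=32,         # LoRA rank r
        low_rank_adaptation_alpha=8.0,      # LoRA delta is scaled by alpha / r; None disables scaling
    )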
File: tests/jax/test_functions.py (new file, mode 100644)

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose

from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:

    def reference(x, la, lb, pattern, scale):
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize('axis_features_pattern',
                             [((-1,), (1024,), '...h,hr,rk->...k'),
                              ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
    @pytest.mark.parametrize('rank', [32, 16])
    @pytest.mark.parametrize('alpha', [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)

        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)

        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize('scope_ref_assert',
                             [('none', LoRAScope(False, False, False), False),
                              ('all', LoRAScope(True, True, True), False),
                              ('qkv_proj', LoRAScope(True, False, False), False),
                              ('output_proj', LoRAScope(False, True, False), False),
                              ('mlp', LoRAScope(False, False, True), False),
                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
                              ('exclude_output_proj', LoRAScope(True, False, True), False),
                              ('exclude_mlp', LoRAScope(True, True, False), False),
                              ('messing_up', LoRAScope(), True)])
    def test_lora_scope_generator(self, scope_ref_assert):
        scope, reference, need_assert = scope_ref_assert
        try:
            lora_scope = _canonicalize_lora_scope(scope)
            assert lora_scope == reference
        except AssertionError as ae:
            assert need_assert, f"{ae.args}"
File: tests/jax/test_praxis_layers.py

@@ -784,6 +784,7 @@ class MultiHeadAttnAttr:
    NUM_GQA_GROUPS = 'num_gqa_groups'
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
    LORA_SCOPE = 'low_rank_adaptation_scope'
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',

@@ -853,6 +854,22 @@ class MultiHeadAttnAttr:
        NUM_ATTN_HEADS: 8,
        NUM_GQA_GROUPS: 4,
        ATTN_MASK_TYPE: 'causal'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'padding',
        LORA_SCOPE: 'all'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        ATTN_MASK_TYPE: 'causal',
        LORA_SCOPE: 'all'
    }]

@@ -883,6 +900,7 @@ class TestMultiHeadAttn(TestLayer):
        attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
        enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
        fuse_qkv_params = True
        transpose_batch_sequence = True
        scale_attn_logits = False

@@ -905,6 +923,7 @@ class TestMultiHeadAttn(TestLayer):
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,

@@ -926,6 +945,7 @@ class TestMultiHeadAttn(TestLayer):
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,

@@ -969,6 +989,7 @@ class TransformerLayerAttr:
    TRANSPOSE_BS = 'transpose_batch_sequence'
    ENABLE_ROPE = 'enable_rotary_pos_emb'
    ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
    LORA_SCOPE = 'low_rank_adaptation_scope'
    ATTRS = [{
        USE_BIAS: True,
        LN_TYPE: 'layernorm',

@@ -1113,6 +1134,16 @@ class TransformerLayerAttr:
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.ENCODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all'
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',

@@ -1185,6 +1216,16 @@ class TransformerLayerAttr:
        ENABLE_ROPE: True,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False
    }, {
        USE_BIAS: True,
        LN_TYPE: 'layernorm',
        ZERO_CEN: False,
        ACTIVATION: ('gelu',),
        LYR_TYPE: TransformerLayerType.DECODER,
        ENABLE_ROPE: False,
        ROPE_GROUP_METHOD: 'consecutive',
        TRANSPOSE_BS: False,
        LORA_SCOPE: 'all'
    }]

@@ -1219,6 +1260,7 @@ class TestTransformer(TestLayer):
        layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
        enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
        enable_relative_embedding = True
        relative_embedding = pax_fiddle.Config(RelativePositionBiases,
                                               dtype=dtype,

@@ -1257,6 +1299,7 @@ class TestTransformer(TestLayer):
            enable_relative_embedding=enable_relative_embedding,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            relative_embedding=relative_embedding,
            drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence)

@@ -1282,6 +1325,7 @@ class TestTransformer(TestLayer):
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            enable_relative_embedding=enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence)
File: transformer_engine/jax/flax/module.py

@@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]):
    return mask


def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
    """Low Rank Adaptation Implementation"""
    assert len(axis) <= 5
    hidden_in_names = 'ijklm'[:len(axis)]
    assert len(features) <= 5
    hidden_out_names = 'nopqr'[:len(features)]
    rank_name = 's'

    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
    rank = lora_a_kernel.shape[-1]
    scaling = alpha / rank if alpha is not None else 1.0

    x_einsum_express = f"...{hidden_in_names}"
    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
    output_einsum_express = f"...{hidden_out_names}"
    final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
                           f"->{output_einsum_express}"

    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
    output = output * scaling
    return output
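To make the expression construction above concrete (an editorial restatement of the code, not text from the commit): with one contracting axis and a single output feature dimension the generated expression is '...i,is,sn->...n'; with one contracting axis and two output dimensions, e.g. a fused projection with features (3, 1024), it is '...i,ins,nso->...no'. Up to letter renaming these are exactly the reference patterns '...h,hr,rk->...k' and '...h,hkr,krz->...kz' used in tests/jax/test_functions.py above.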
class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Applies softmax over a mini-batch of inputs.
@@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase):
    bias_axes: Tuple[str, ...], default = ()
        The name of axes used to shard bias with a corresponding mesh,
        only used when :attr:`use_bias=True`.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
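Spelling out what the scaling note above means for the whole layer (an editorial restatement): with base kernel W and LoRA kernels A (lora_a_kernel) and B (lora_b_kernel) of rank r, the forward pass below computes

    y = x · W + scale · ((x · A) · B) + bias,    where scale = alpha / r if alpha is not None else 1.0

so the adapter adds a low-rank correction on top of the dense projection.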
@@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase):
    use_bias: bool = True
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = False

@@ -439,6 +475,32 @@ class DenseGeneral(TransformerEngineBase):
                              fp8_meta_pkg=fp8_gemm_pkg,
                              contracting_dims=(axis, contract_ind))

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
                                   self.low_rank_adaptation_dim)
            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
                                        self.low_rank_adaptation_dim)
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
                                                            self.kernel_init,
                                                            lora_a_kernel_init_shape,
                                                            jnp.float32,
                                                            axes=lora_a_kernel_axes)
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
                                                            nn.initializers.zeros,
                                                            lora_b_kernel_shape,
                                                            jnp.float32,
                                                            axes=lora_b_kernel_axes)
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
                                            self.low_rank_adaptation_alpha)

        if bias is not None:
            bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
            y += jnp.reshape(bias, bias_shape)
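Editor's note: lora_a_kernel is created with the layer's regular kernel_init while lora_b_kernel is zero-initialized (nn.initializers.zeros), so the LoRA branch contributes exactly zero at initialization and the adapted layer starts out numerically identical to the un-adapted one. The same pattern repeats in the LayerNormDenseGeneral and LayerNormMLP changes below.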
@@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
    return_layernorm_output: bool, default = True
        Indicate whether to return the output of layer normalization.
        If set False, return None as the second tensor in outputs.
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None

@@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
    bias_init: Initializer = nn.initializers.zeros
    bias_axes: Tuple[str, ...] = ()
    return_layernorm_output: bool = True
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True

@@ -650,6 +723,32 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                              fp8_meta_pkg=fp8_meta_package,
                              contracting_dims=(axis, contract_ind))

        if self.enable_low_rank_adaptation:
            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
                                   self.low_rank_adaptation_dim)
            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
                                        self.low_rank_adaptation_dim)
            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
                                                            self.kernel_init,
                                                            lora_a_kernel_init_shape,
                                                            jnp.float32,
                                                            axes=lora_a_kernel_axes)
            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
            lora_a_kernel = lora_a_kernel.astype(self.dtype)

            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
                                                            nn.initializers.zeros,
                                                            lora_b_kernel_shape,
                                                            jnp.float32,
                                                            axes=lora_b_kernel_axes)
            lora_b_kernel = lora_b_kernel.astype(self.dtype)

            z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
                                            self.low_rank_adaptation_alpha)

        bias = None
        if self.use_bias:
            bias = nn_partitioning.param_with_axes('bias',

@@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase):
        Dropout probability for the dropout op after the :attr:`activations`.
    intermediate_hidden_dropout_dims: Sequence[int], default = ()
        Dimensions that will share the same dropout mask for hidden
    enable_low_rank_adaptation: bool, default = False
        Indicate whether to enable low rank adaptation for each linear layer.
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`.
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    axis: Union[Iterable[int], int], default = -1
        An integer tuple with axes to apply the transformation on.
    layernorm_input_axes: Tuple[str, ...], default = None

@@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase):
    intermediate_dropout_rng_name: str = 'dropout'
    intermediate_dropout_rate: float = 0.1
    intermediate_hidden_dropout_dims: Sequence[int] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    dtype: DType = jnp.float32
    transpose_batch_sequence: bool = True

@@ -856,11 +966,13 @@ class LayerNormMLP(TransformerEngineBase):
        use_fused_ln_geglu_mlp = fuse_layernorm \
            and (not self.use_bias) and is_geglu(self.activations) \
            and (self.intermediate_dropout_rate < 1e-3) \
            and not self.enable_low_rank_adaptation

        use_fused_ln_gelu_mlp = fuse_layernorm \
            and self.use_bias and is_gelu(self.activations) \
            and (self.intermediate_dropout_rate < 1e-3) \
            and not self.enable_low_rank_adaptation

        # LayerNorm
        if self.enable_layernorm:

@@ -999,6 +1111,37 @@ class LayerNormMLP(TransformerEngineBase):
                                     fp8_meta_pkg=gemm1_fp8_meta_package,
                                     contracting_dims=(axis, contract_ind))

            if self.enable_low_rank_adaptation:
                wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
                                          self.low_rank_adaptation_dim)
                wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
                                               self.low_rank_adaptation_dim)
                wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
                                                    self.low_rank_adaptation_dim)
                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
                wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
                                                                   kernel_1_init,
                                                                   num_activations,
                                                                   -2,
                                                                   wi_lora_a_kernel_init_each_shape,
                                                                   jnp.float32,
                                                                   axes=wi_lora_a_kernel_axes)
                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)

                wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
                                          self.intermediate_dim)
                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
                wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
                                                                   nn.initializers.zeros,
                                                                   wi_lora_b_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wi_lora_b_kernel_axes)
                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)

                x += _apply_low_rank_adaptation(y, axis, intermediate_dim,
                                                wi_lora_a_kernel, wi_lora_b_kernel,
                                                self.low_rank_adaptation_alpha)

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wi_bias',

@@ -1042,6 +1185,28 @@ class LayerNormMLP(TransformerEngineBase):
                                       fp8_meta_pkg=gemm2_fp8_meta_package,
                                       contracting_dims=(axis, contract_ind))

            if self.enable_low_rank_adaptation:
                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
                wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
                                                                   self.kernel_init,
                                                                   wo_lora_a_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wo_lora_a_kernel_axes)
                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)

                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
                wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
                                                                   nn.initializers.zeros,
                                                                   wo_lora_b_kernel_shape,
                                                                   jnp.float32,
                                                                   axes=wo_lora_b_kernel_axes)
                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)

                out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple,
                                                  wo_lora_a_kernel, wo_lora_b_kernel,
                                                  self.low_rank_adaptation_alpha)

            bias = None
            if self.use_bias:
                bias = nn_partitioning.param_with_axes('wo_bias',
File: transformer_engine/jax/flax/transformer.py

@@ -637,6 +637,53 @@ def rotary_pos_emb(x: Array,
    return consecutive_impl()


class LoRAScope:    # pylint: disable=too-few-public-methods
    """LoRA Scope"""

    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
        self.qkv_proj = qkv_proj
        self.output_proj = output_proj
        self.mlp = mlp

    def __eq__(self, other):
        return (self.qkv_proj, self.output_proj, self.mlp) == \
               (other.qkv_proj, other.output_proj, other.mlp)


def _canonicalize_lora_scope(scope):
    SCOPE_NONE = 'none'
    SCOPE_ALL = 'all'
    SCOPE_QKV_PROJ = 'qkv_proj'
    SCOPE_OUTPUT_PROJ = 'output_proj'
    SCOPE_MLP = 'mlp'
    SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj'
    SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
    SCOPE_EX_MLP = 'exclude_mlp'

    scope = SCOPE_NONE if scope is None else scope
    scope = scope.lower()

    assert scope in [
        SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP, SCOPE_EX_QKV_PROJ,
        SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP
    ]

    lora_scope = LoRAScope()

    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
        lora_scope.qkv_proj = True
    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
        lora_scope.output_proj = True
    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
        lora_scope.mlp = True

    return lora_scope
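The mapping is easiest to read off from a few cases (an editorial sketch mirroring the test_lora_scope_generator parametrization in tests/jax/test_functions.py above):

    from transformer_engine.jax.flax.transformer import LoRAScope, _canonicalize_lora_scope

    assert _canonicalize_lora_scope('qkv_proj') == LoRAScope(True, False, False)
    assert _canonicalize_lora_scope('exclude_mlp') == LoRAScope(True, True, False)
    assert _canonicalize_lora_scope('all') == LoRAScope(True, True, True)
    assert _canonicalize_lora_scope(None) == LoRAScope(False, False, False)    # None falls back to 'none'
    # Any other string (e.g. 'messing_up') trips the assert inside _canonicalize_lora_scope.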
class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
    r"""
    Multi-head Attention (MHA), including Query,

@@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
        Indicate the method to coupled the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.
    num_heads: int, default = None

@@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True

@@ -914,6 +973,8 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
            inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        if self.fuse_qkv_params:
            if is_qkvpack:
                qkv_proj, ln_out = LayerNormDenseGeneral(

@@ -932,6 +993,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,
                    dot_input_axes=inputs_logical_axes_no_sp,
                    name='qkv',

@@ -954,6 +1018,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_TP_AXES,),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    dtype=self.dtype,
                    kernel_init=query_init,
                    layernorm_input_axes=inputs_logical_axes_maybe_sp,

@@ -972,6 +1039,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                    use_bias=self.use_bias,
                    bias_init=self.bias_init,
                    bias_axes=(W_JOINED_AXES, W_TP_AXES),
                    enable_low_rank_adaptation=lora_scope.qkv_proj,
                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                    name='kv',
                    dtype=self.dtype)(inputs_kv)
                kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')

@@ -986,6 +1056,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype)
            query, ln_out = LayerNormDenseGeneral(
                enable_layernorm=self.input_layernorm,

@@ -1002,6 +1075,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                use_bias=self.use_bias,
                bias_init=self.bias_init,
                bias_axes=(W_TP_AXES,),
                enable_low_rank_adaptation=lora_scope.qkv_proj,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                dtype=self.dtype,
                kernel_init=query_init,
                layernorm_input_axes=inputs_logical_axes_maybe_sp,

@@ -1142,6 +1218,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
            use_bias=self.use_bias,
            bias_init=self.bias_init,
            bias_axes=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.output_proj,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            dtype=self.dtype,
            name='out')(x)
        out = checkpoint_name(out, 'out_proj')

@@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
        Indicate the method to coupled the coordinates. It should be one of
        ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
        , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
    low_rank_adaptation_scope: str, default = 'none'
        Indicate the scope to apply low rank adaptation. It should be one of
        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
        'exclude_output_proj', 'exclude_mlp']
    low_rank_adaptation_dim: int, default = 32
        The dimension for low rank adaptation, only used when
        :attr:`enable_low_rank_adaptation=True`
    low_rank_adaptation_alpha: float, default = None
        The alpha for computing the scaling factor of LoRA output.
        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
    enable_sequence_parallel: bool, default = False
        Whether to enable sequence parallelism to operations except dot.

@@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    dtype: DType = jnp.float32
    drop_path: float = 0.0
    fuse_qkv_params: bool = True

@@ -1579,6 +1671,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            kernel_init=self.mha_kernel_init,
            use_bias=self.use_bias,

@@ -1646,6 +1741,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
                enable_rotary_pos_emb=self.enable_rotary_pos_emb,
                rotary_pos_emb_windows=self.rotary_pos_emb_windows,
                rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
                low_rank_adaptation_scope=self.low_rank_adaptation_scope,
                low_rank_adaptation_dim=self.low_rank_adaptation_dim,
                low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                float32_logits=self.float32_attention_logits,
                scale_attn_logits=self.scale_attn_logits,
                scaled_query_init=self.scaled_query_init,

@@ -1674,6 +1772,8 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
            mlp_input = with_sharding_constraint_by_logical_axes(
                mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)

        # MlpBlock
        residual = mlp_input
        z, ln_out = LayerNormMLP(

@@ -1697,6 +1797,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
            bias_init=self.bias_init,
            bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
            bias_axes_2=(W_NO_SHARD_AXES,),
            enable_low_rank_adaptation=lora_scope.mlp,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
            dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
            dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
File: transformer_engine/jax/praxis/module.py

@@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer):
    use_bias: bool = True
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False
    sharding_type: ShardingType = ShardingType.SINGLE

@@ -147,6 +150,9 @@ class Linear(TransformerEngineBaseLayer):
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            axis=self.axis,
            dtype=self.dtype,
            transpose_batch_sequence=self.transpose_batch_sequence)

@@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
    use_bias: bool = False
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    return_layernorm_output: bool = True
    axis: Union[Iterable[int], int] = -1
    transpose_batch_sequence: bool = False

@@ -201,6 +210,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
            use_bias=self.use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes=self.bias_axes,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            axis=self.axis,
            dtype=self.dtype,

@@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
    bias_init: WeightInit = WeightInit.Constant(0.0)
    bias_axes_1: Tuple[str, ...] = ()
    bias_axes_2: Tuple[str, ...] = ()
    enable_low_rank_adaptation: bool = False
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    return_layernorm_output: bool = True
    activations: Sequence[Union[str, Callable]] = ('relu',)
    intermediate_dropout_rate: float = 0.1

@@ -263,6 +278,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
            bias_axes_1=self.bias_axes_1,
            bias_axes_2=self.bias_axes_2,
            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            return_layernorm_output=self.return_layernorm_output,
            activations=self.activations,
            intermediate_dropout_rate=self.intermediate_dropout_rate,
File: transformer_engine/jax/praxis/transformer.py

@@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    fuse_qkv_params: bool = True
    transpose_batch_sequence: bool = True
    enable_sequence_parallel: bool = False

@@ -208,6 +211,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            fuse_qkv_params=self.fuse_qkv_params,
            transpose_batch_sequence=self.transpose_batch_sequence,
            enable_sequence_parallel=self.enable_sequence_parallel,

@@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
    enable_rotary_pos_emb: bool = False
    rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
    rotary_pos_emb_group_method: str = 'consecutive'
    low_rank_adaptation_scope: str = 'none'
    low_rank_adaptation_dim: int = 32
    low_rank_adaptation_alpha: float = None
    enable_relative_embedding: bool = True
    relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
    drop_path: float = 0.0

@@ -332,6 +341,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
            enable_rotary_pos_emb=self.enable_rotary_pos_emb,
            rotary_pos_emb_windows=self.rotary_pos_emb_windows,
            rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
            enable_relative_embedding=self.enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            drop_path=self.drop_path,