OpenDAS / TransformerEngine / Commits

Commit 7c1828f8 (unverified), authored Apr 16, 2024 by Ming-Xu Huang, committed by GitHub on Apr 16, 2024.
Support Low Rank Adaptation (LoRA). (#745)
Parent: 1442b47e

Showing 6 changed files with 412 additions and 2 deletions (+412, -2).
tests/jax/test_functions.py                    +68   -0
tests/jax/test_praxis_layers.py                +44   -0
transformer_engine/jax/flax/module.py          +167  -2
transformer_engine/jax/flax/transformer.py     +103  -0
transformer_engine/jax/praxis/module.py        +18   -0
transformer_engine/jax/praxis/transformer.py   +12   -0
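For orientation before the per-file diffs, here is a minimal, untested sketch (not part of the commit) of how the new LoRA options are meant to be used from the Flax API. Field names come from the diff below; the sizes, mask convention, and the assumption that TransformerLayer accepts hidden_size, mlp_hidden_size, num_attention_heads, transpose_batch_sequence and attention_mask are illustrative.

    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.flax import TransformerLayer

    # LoRA is switched on per sub-block via the new scope string added in this commit.
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
        low_rank_adaptation_scope='all',   # 'none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_*'
        low_rank_adaptation_dim=32,        # rank of the A/B adapter factors
        low_rank_adaptation_alpha=64.0,    # LoRA output scaled by alpha / rank; None disables scaling
        transpose_batch_sequence=False,    # inputs as [batch, seq, hidden] in this sketch
    )

    x = jnp.zeros((2, 128, 1024), jnp.bfloat16)
    mask = jnp.zeros((2, 1, 128, 128), jnp.uint8)
    params = layer.init(jax.random.PRNGKey(0), x, attention_mask=mask)
    # The adapter weights show up as 'lora_a_kernel' / 'lora_b_kernel' under the enabled sub-modules.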
tests/jax/test_functions.py  (new file, mode 0 → 100644)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest

import jax
import jax.numpy as jnp

from utils import assert_allclose

from transformer_engine.jax.flax.module import _apply_low_rank_adaptation
from transformer_engine.jax.flax.module import _normalize_axes
from transformer_engine.jax.flax.transformer import LoRAScope
from transformer_engine.jax.flax.transformer import _canonicalize_lora_scope


class TestLoRA:

    def reference(x, la, lb, pattern, scale):
        out = jnp.einsum(pattern, x, la, lb)
        return out * scale

    @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
    @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
    @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
                                                       ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
    @pytest.mark.parametrize('rank', [32, 16])
    @pytest.mark.parametrize('alpha', [None, 4, 8])
    def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
        axis, features, pattern = axis_features_pattern
        axis = _normalize_axes(axis, len(shape))
        shape_in_axis = tuple(shape[ax] for ax in axis)

        key = jax.random.key(1124)

        key, x_key = jax.random.split(key)
        x = jax.random.normal(x_key, shape, dtype)

        key, la_key = jax.random.split(key)
        la_shape = (*shape_in_axis, *features[:-1], rank)
        la = jax.random.normal(la_key, la_shape, dtype)

        key, lb_key = jax.random.split(key)
        lb_shape = (*features[:-1], rank, features[-1])
        lb = jax.random.normal(lb_key, lb_shape, dtype)

        out_target = _apply_low_rank_adaptation(x, axis, features, la, lb, alpha)

        scale_ref = alpha / rank if alpha is not None else 1.0
        out_ref = TestLoRA.reference(x, la, lb, pattern, scale_ref)

        assert_allclose(out_target, out_ref, dtype=dtype)

    @pytest.mark.parametrize('scope_ref_assert',
                             [('none', LoRAScope(False, False, False), False),
                              ('all', LoRAScope(True, True, True), False),
                              ('qkv_proj', LoRAScope(True, False, False), False),
                              ('output_proj', LoRAScope(False, True, False), False),
                              ('mlp', LoRAScope(False, False, True), False),
                              ('exclude_qkv_proj', LoRAScope(False, True, True), False),
                              ('exclude_output_proj', LoRAScope(True, False, True), False),
                              ('exclude_mlp', LoRAScope(True, True, False), False),
                              ('messing_up', LoRAScope(), True)])
    def test_lora_scope_generator(self, scope_ref_assert):
        scope, reference, need_assert = scope_ref_assert
        try:
            lora_scope = _canonicalize_lora_scope(scope)
            assert lora_scope == reference
        except AssertionError as ae:
            assert need_assert, f"{ae.args}"
tests/jax/test_praxis_layers.py
@@ -784,6 +784,7 @@ class MultiHeadAttnAttr:
     NUM_GQA_GROUPS = 'num_gqa_groups'
     ENABLE_ROPE = 'enable_rotary_pos_emb'
     ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
+    LORA_SCOPE = 'low_rank_adaptation_scope'
     ATTRS = [{
         USE_BIAS: True,
         LN_TYPE: 'layernorm',

@@ -853,6 +854,22 @@ class MultiHeadAttnAttr:
         NUM_ATTN_HEADS: 8,
         NUM_GQA_GROUPS: 4,
         ATTN_MASK_TYPE: 'causal'
+    }, {
+        USE_BIAS: True,
+        LN_TYPE: 'layernorm',
+        ZERO_CEN: False,
+        ENABLE_ROPE: False,
+        ROPE_GROUP_METHOD: 'consecutive',
+        ATTN_MASK_TYPE: 'padding',
+        LORA_SCOPE: 'all'
+    }, {
+        USE_BIAS: True,
+        LN_TYPE: 'layernorm',
+        ZERO_CEN: False,
+        ENABLE_ROPE: False,
+        ROPE_GROUP_METHOD: 'consecutive',
+        ATTN_MASK_TYPE: 'causal',
+        LORA_SCOPE: 'all'
     }]

@@ -883,6 +900,7 @@ class TestMultiHeadAttn(TestLayer):
         attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
         enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
         rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
+        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
         fuse_qkv_params = True
         transpose_batch_sequence = True
         scale_attn_logits = False

@@ -905,6 +923,7 @@ class TestMultiHeadAttn(TestLayer):
             attn_mask_type=attn_mask_type,
             enable_rotary_pos_emb=enable_rotary_pos_emb,
             rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=low_rank_adaptation_scope,
             fuse_qkv_params=fuse_qkv_params,
             transpose_batch_sequence=transpose_batch_sequence,
             scale_attn_logits=scale_attn_logits,

@@ -926,6 +945,7 @@ class TestMultiHeadAttn(TestLayer):
             attn_mask_type=attn_mask_type,
             enable_rotary_pos_emb=enable_rotary_pos_emb,
             rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=low_rank_adaptation_scope,
             fuse_qkv_params=fuse_qkv_params,
             transpose_batch_sequence=transpose_batch_sequence,
             scale_attn_logits=scale_attn_logits,

@@ -969,6 +989,7 @@ class TransformerLayerAttr:
     TRANSPOSE_BS = 'transpose_batch_sequence'
     ENABLE_ROPE = 'enable_rotary_pos_emb'
     ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
+    LORA_SCOPE = 'low_rank_adaptation_scope'
     ATTRS = [{
         USE_BIAS: True,
         LN_TYPE: 'layernorm',

@@ -1113,6 +1134,16 @@ class TransformerLayerAttr:
         ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: 'consecutive',
         TRANSPOSE_BS: False
+    }, {
+        USE_BIAS: True,
+        LN_TYPE: 'layernorm',
+        ZERO_CEN: False,
+        ACTIVATION: ('gelu',),
+        LYR_TYPE: TransformerLayerType.ENCODER,
+        ENABLE_ROPE: False,
+        ROPE_GROUP_METHOD: 'consecutive',
+        TRANSPOSE_BS: False,
+        LORA_SCOPE: 'all'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',

@@ -1185,6 +1216,16 @@ class TransformerLayerAttr:
         ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: 'consecutive',
         TRANSPOSE_BS: False
+    }, {
+        USE_BIAS: True,
+        LN_TYPE: 'layernorm',
+        ZERO_CEN: False,
+        ACTIVATION: ('gelu',),
+        LYR_TYPE: TransformerLayerType.DECODER,
+        ENABLE_ROPE: False,
+        ROPE_GROUP_METHOD: 'consecutive',
+        TRANSPOSE_BS: False,
+        LORA_SCOPE: 'all'
     }]

@@ -1219,6 +1260,7 @@ class TestTransformer(TestLayer):
         layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
         enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
         rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
+        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
         enable_relative_embedding = True
         relative_embedding = pax_fiddle.Config(RelativePositionBiases,
                                                dtype=dtype,

@@ -1257,6 +1299,7 @@ class TestTransformer(TestLayer):
             enable_relative_embedding=enable_relative_embedding,
             enable_rotary_pos_emb=enable_rotary_pos_emb,
             rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=low_rank_adaptation_scope,
             relative_embedding=relative_embedding,
             drop_path=drop_path,
             transpose_batch_sequence=transpose_batch_sequence)

@@ -1282,6 +1325,7 @@ class TestTransformer(TestLayer):
             rotary_pos_emb_group_method=rotary_pos_emb_group_method,
             enable_relative_embedding=enable_relative_embedding,
             relative_embedding=relative_embedding_flax_module,
+            low_rank_adaptation_scope=low_rank_adaptation_scope,
             drop_path=drop_path,
             transpose_batch_sequence=transpose_batch_sequence)
transformer_engine/jax/flax/module.py
@@ -104,6 +104,31 @@ def _combine_biases(*masks: List[Array]):
     return mask


+def _apply_low_rank_adaptation(x, axis, features, lora_a_kernel, lora_b_kernel, alpha):
+    """Low Rank Adaptation Implementation"""
+    assert len(axis) <= 5
+    hidden_in_names = 'ijklm'[:len(axis)]
+    assert len(features) <= 5
+    hidden_out_names = 'nopqr'[:len(features)]
+    rank_name = 's'
+
+    assert lora_a_kernel.shape[-1] == lora_b_kernel.shape[-2]
+    rank = lora_a_kernel.shape[-1]
+    scaling = alpha / rank if alpha is not None else 1.0
+
+    x_einsum_express = f"...{hidden_in_names}"
+    lora_a_einsum_express = f"{hidden_in_names}{hidden_out_names[:-1]}{rank_name}"
+    lora_b_einsum_express = f"{hidden_out_names[:-1]}{rank_name}{hidden_out_names[-1]}"
+    output_einsum_express = f"...{hidden_out_names}"
+    final_einsum_express = f"{x_einsum_express},{lora_a_einsum_express},{lora_b_einsum_express}" \
+                           f"->{output_einsum_express}"
+
+    output = jnp.einsum(final_einsum_express, x, lora_a_kernel, lora_b_kernel)
+    output = output * scaling
+    return output
+
+
 class Softmax(nn.Module):    # pylint: disable=too-few-public-methods
     r"""
     Applies softmax over a mini-batch of inputs.
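As a hand-worked illustration (ours, not part of the diff) of the einsum strings built above, take the fused-QKV case exercised by tests/jax/test_functions.py: one contracted input axis and features = (3, 1024).

    # len(axis) == 1     -> hidden_in_names  == 'i'
    # len(features) == 2 -> hidden_out_names == 'no', rank_name == 's'
    # final_einsum_express == '...i,ins,nso->...no'
    # i.e. x: (..., hidden), lora_a_kernel: (hidden, 3, rank), lora_b_kernel: (3, rank, 1024),
    # the same contraction the test reference writes as '...h,hkr,krz->...kz'.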
@@ -355,6 +380,14 @@ class DenseGeneral(TransformerEngineBase):
     bias_axes: Tuple[str, ...], default = ()
         The name of axes used to shard bias with a corresponding mesh,
         only used when :attr:`use_bias=True`.
+    enable_low_rank_adaptation: bool, default = False
+        Indicate whether to enable low rank adaptation for each linear layer.
+    low_rank_adaptation_dim: int, default = 32
+        The dimension for low rank adaptation, only used when
+        :attr:`enable_low_rank_adaptation=True`
+    low_rank_adaptation_alpha: float, default = None
+        The alpha for computing the scaling factor of LoRA output.
+        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
     axis: Union[Iterable[int], int], default = -1
         An integer tuple with axes to apply the transformation on.

@@ -374,6 +407,9 @@ class DenseGeneral(TransformerEngineBase):
     use_bias: bool = True
     bias_init: Initializer = nn.initializers.zeros
     bias_axes: Tuple[str, ...] = ()
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = False

@@ -439,6 +475,32 @@ class DenseGeneral(TransformerEngineBase):
                          fp8_meta_pkg=fp8_gemm_pkg,
                          contracting_dims=(axis, contract_ind))

+        if self.enable_low_rank_adaptation:
+            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
+                                   self.low_rank_adaptation_dim)
+            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
+                                        self.low_rank_adaptation_dim)
+            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
+            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
+                                                            self.kernel_init,
+                                                            lora_a_kernel_init_shape,
+                                                            jnp.float32,
+                                                            axes=lora_a_kernel_axes)
+            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
+            lora_a_kernel = lora_a_kernel.astype(self.dtype)
+
+            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
+            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
+            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
+                                                            nn.initializers.zeros,
+                                                            lora_b_kernel_shape,
+                                                            jnp.float32,
+                                                            axes=lora_b_kernel_axes)
+            lora_b_kernel = lora_b_kernel.astype(self.dtype)
+
+            y += _apply_low_rank_adaptation(inputs, axis, features, lora_a_kernel, lora_b_kernel,
+                                            self.low_rank_adaptation_alpha)
+
         if bias is not None:
             bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape
             y += jnp.reshape(bias, bias_shape)
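A note on initialization (our reading, not stated in the diff): lora_b_kernel is created with nn.initializers.zeros while lora_a_kernel uses the layer's regular kernel_init, so, schematically,

    y = x @ kernel + (alpha / rank) * ((x @ lora_a_kernel) @ lora_b_kernel)

and the LoRA term is exactly zero right after initialization. The adapted layer therefore starts out numerically identical to the non-adapted one, which matches the standard LoRA recipe.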
@@ -502,6 +564,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     return_layernorm_output: bool, default = True
         Indicate whether to return the output of layer normalization.
         If set False, return None as the second tensor in outputs.
+    enable_low_rank_adaptation: bool, default = False
+        Indicate whether to enable low rank adaptation for each linear layer.
+    low_rank_adaptation_dim: int, default = 32
+        The dimension for low rank adaptation, only used when
+        :attr:`enable_low_rank_adaptation=True`
+    low_rank_adaptation_alpha: float, default = None
+        The alpha for computing the scaling factor of LoRA output.
+        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
     axis: Union[Iterable[int], int], default = -1
         An integer tuple with axes to apply the transformation on.
     layernorm_input_axes: Tuple[str, ...], default = None

@@ -541,6 +611,9 @@ class LayerNormDenseGeneral(TransformerEngineBase):
     bias_init: Initializer = nn.initializers.zeros
     bias_axes: Tuple[str, ...] = ()
     return_layernorm_output: bool = True
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True

@@ -650,6 +723,32 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                         fp8_meta_pkg=fp8_meta_package,
                         contracting_dims=(axis, contract_ind))

+        if self.enable_low_rank_adaptation:
+            lora_a_kernel_shape = (*kernel_shape[:len(axis)], *features[:-1],
+                                   self.low_rank_adaptation_dim)
+            lora_a_kernel_init_shape = (kernel_param_shape[0], *features[:-1],
+                                        self.low_rank_adaptation_dim)
+            lora_a_kernel_axes = (None,) * len(lora_a_kernel_init_shape)
+            lora_a_kernel = nn_partitioning.param_with_axes('lora_a_kernel',
+                                                            self.kernel_init,
+                                                            lora_a_kernel_init_shape,
+                                                            jnp.float32,
+                                                            axes=lora_a_kernel_axes)
+            lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape)
+            lora_a_kernel = lora_a_kernel.astype(self.dtype)
+
+            lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
+            lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
+            lora_b_kernel = nn_partitioning.param_with_axes('lora_b_kernel',
+                                                            nn.initializers.zeros,
+                                                            lora_b_kernel_shape,
+                                                            jnp.float32,
+                                                            axes=lora_b_kernel_axes)
+            lora_b_kernel = lora_b_kernel.astype(self.dtype)
+
+            z += _apply_low_rank_adaptation(y, axis, features, lora_a_kernel, lora_b_kernel,
+                                            self.low_rank_adaptation_alpha)
+
         bias = None
         if self.use_bias:
             bias = nn_partitioning.param_with_axes('bias',
@@ -745,6 +844,14 @@ class LayerNormMLP(TransformerEngineBase):
         Dropout probability for the dropout op after the :attr:`activations`.
     intermediate_hidden_dropout_dims: Sequence[int], default = ()
         Dimensions that will share the same dropout mask for hidden
+    enable_low_rank_adaptation: bool, default = False
+        Indicate whether to enable low rank adaptation for each linear layer.
+    low_rank_adaptation_dim: int, default = 32
+        The dimension for low rank adaptation, only used when
+        :attr:`enable_low_rank_adaptation=True`.
+    low_rank_adaptation_alpha: float, default = None
+        The alpha for computing the scaling factor of LoRA output.
+        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
     axis: Union[Iterable[int], int], default = -1
         An integer tuple with axes to apply the transformation on.
     layernorm_input_axes: Tuple[str, ...], default = None

@@ -791,6 +898,9 @@ class LayerNormMLP(TransformerEngineBase):
     intermediate_dropout_rng_name: str = 'dropout'
     intermediate_dropout_rate: float = 0.1
     intermediate_hidden_dropout_dims: Sequence[int] = ()
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     dtype: DType = jnp.float32
     transpose_batch_sequence: bool = True

@@ -856,11 +966,13 @@ class LayerNormMLP(TransformerEngineBase):
         use_fused_ln_geglu_mlp = fuse_layernorm \
             and (not self.use_bias) and is_geglu(self.activations) \
-            and (self.intermediate_dropout_rate < 1e-3)
+            and (self.intermediate_dropout_rate < 1e-3) \
+            and not self.enable_low_rank_adaptation

         use_fused_ln_gelu_mlp = fuse_layernorm \
             and self.use_bias and is_gelu(self.activations) \
-            and (self.intermediate_dropout_rate < 1e-3)
+            and (self.intermediate_dropout_rate < 1e-3) \
+            and not self.enable_low_rank_adaptation

         # LayerNorm
         if self.enable_layernorm:

@@ -999,6 +1111,37 @@ class LayerNormMLP(TransformerEngineBase):
                            fp8_meta_pkg=gemm1_fp8_meta_package,
                            contracting_dims=(axis, contract_ind))

+            if self.enable_low_rank_adaptation:
+                wi_lora_a_kernel_shape = (*kernel_1_shape[:len(axis)], num_activations,
+                                          self.low_rank_adaptation_dim)
+                wi_lora_a_kernel_init_shape = (kernel_1_each_shape[0], num_activations,
+                                               self.low_rank_adaptation_dim)
+                wi_lora_a_kernel_init_each_shape = (kernel_1_each_shape[0],
+                                                    self.low_rank_adaptation_dim)
+                wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_init_shape)
+                wi_lora_a_kernel = nn_partitioning.param_with_axes('wi_lora_a_kernel',
+                                                                   kernel_1_init,
+                                                                   num_activations,
+                                                                   -2,
+                                                                   wi_lora_a_kernel_init_each_shape,
+                                                                   jnp.float32,
+                                                                   axes=wi_lora_a_kernel_axes)
+                wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape)
+                wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype)
+
+                wi_lora_b_kernel_shape = (num_activations, self.low_rank_adaptation_dim,
+                                          self.intermediate_dim)
+                wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
+                wi_lora_b_kernel = nn_partitioning.param_with_axes('wi_lora_b_kernel',
+                                                                   nn.initializers.zeros,
+                                                                   wi_lora_b_kernel_shape,
+                                                                   jnp.float32,
+                                                                   axes=wi_lora_b_kernel_axes)
+                wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype)
+
+                x += _apply_low_rank_adaptation(y, axis, intermediate_dim,
+                                                wi_lora_a_kernel, wi_lora_b_kernel,
+                                                self.low_rank_adaptation_alpha)
+
             bias = None
             if self.use_bias:
                 bias = nn_partitioning.param_with_axes('wi_bias',

@@ -1042,6 +1185,28 @@ class LayerNormMLP(TransformerEngineBase):
                              fp8_meta_pkg=gemm2_fp8_meta_package,
                              contracting_dims=(axis, contract_ind))

+            if self.enable_low_rank_adaptation:
+                wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
+                wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
+                wo_lora_a_kernel = nn_partitioning.param_with_axes('wo_lora_a_kernel',
+                                                                   self.kernel_init,
+                                                                   wo_lora_a_kernel_shape,
+                                                                   jnp.float32,
+                                                                   axes=wo_lora_a_kernel_axes)
+                wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype)
+
+                wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
+                wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
+                wo_lora_b_kernel = nn_partitioning.param_with_axes('wo_lora_b_kernel',
+                                                                   nn.initializers.zeros,
+                                                                   wo_lora_b_kernel_shape,
+                                                                   jnp.float32,
+                                                                   axes=wo_lora_b_kernel_axes)
+                wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype)
+
+                out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple,
+                                                  wo_lora_a_kernel, wo_lora_b_kernel,
+                                                  self.low_rank_adaptation_alpha)
+
             bias = None
             if self.use_bias:
                 bias = nn_partitioning.param_with_axes('wo_bias',
transformer_engine/jax/flax/transformer.py
@@ -637,6 +637,53 @@ def rotary_pos_emb(x: Array,
     return consecutive_impl()


+class LoRAScope:    # pylint: disable=too-few-public-methods
+    """LoRA Scope"""
+
+    def __init__(self, qkv_proj=False, output_proj=False, mlp=False):
+        self.qkv_proj = qkv_proj
+        self.output_proj = output_proj
+        self.mlp = mlp
+
+    def __eq__(self, other):
+        return (self.qkv_proj, self.output_proj, self.mlp) == \
+               (other.qkv_proj, other.output_proj, other.mlp)
+
+
+def _canonicalize_lora_scope(scope):
+
+    SCOPE_NONE = 'none'
+    SCOPE_ALL = 'all'
+    SCOPE_QKV_PROJ = 'qkv_proj'
+    SCOPE_OUTPUT_PROJ = 'output_proj'
+    SCOPE_MLP = 'mlp'
+    SCOPE_EX_QKV_PROJ = 'exclude_qkv_proj'
+    SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
+    SCOPE_EX_MLP = 'exclude_mlp'
+
+    scope = SCOPE_NONE if scope is None else scope
+    scope = scope.lower()
+
+    assert scope in [
+        SCOPE_NONE, SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_OUTPUT_PROJ, SCOPE_MLP,
+        SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP
+    ]
+
+    lora_scope = LoRAScope()
+
+    if scope in [SCOPE_ALL, SCOPE_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]:
+        lora_scope.qkv_proj = True
+    if scope in [SCOPE_ALL, SCOPE_OUTPUT_PROJ, SCOPE_EX_QKV_PROJ, SCOPE_EX_MLP]:
+        lora_scope.output_proj = True
+    if scope in [SCOPE_ALL, SCOPE_MLP, SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ]:
+        lora_scope.mlp = True
+
+    return lora_scope
+
+
 class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
     r"""
     Multi-head Attention (MHA), including Query,
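For reference (derived from the helper above, not part of the diff), the scope strings map onto LoRAScope flags as follows:

    _canonicalize_lora_scope('all')           # LoRAScope(qkv_proj=True,  output_proj=True,  mlp=True)
    _canonicalize_lora_scope('exclude_mlp')   # LoRAScope(qkv_proj=True,  output_proj=True,  mlp=False)
    _canonicalize_lora_scope('QKV_Proj')      # case-insensitive -> LoRAScope(True, False, False)
    _canonicalize_lora_scope(None)            # treated as 'none' -> LoRAScope(False, False, False)
    _canonicalize_lora_scope('messing_up')    # unknown string -> fails the assert above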
@@ -723,6 +770,15 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
         Indicate the method to coupled the coordinates. It should be one of
         ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
         , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
+    low_rank_adaptation_scope: str, default = 'none'
+        Indicate the scope to apply low rank adaptation. It should be one of
+        ['none', 'all', 'qkv_proj', 'output_proj', 'exclude_qkv_proj', 'exclude_output_proj']
+    low_rank_adaptation_dim: int, default = 32
+        The dimension for low rank adaptation, only used when
+        :attr:`enable_low_rank_adaptation=True`
+    low_rank_adaptation_alpha: float, default = None
+        The alpha for computing the scaling factor of LoRA output.
+        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
     enable_sequence_parallel: bool, default = False
         Whether to enable sequence parallelism to operations except dot.
     num_heads: int, default = None

@@ -777,6 +833,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     rotary_pos_emb_group_method: str = 'consecutive'
+    low_rank_adaptation_scope: str = 'none'
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     dtype: DType = jnp.float32
     fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = True

@@ -914,6 +973,8 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
         inputs_q = with_sharding_constraint_by_logical_axes(inputs_q, inputs_logical_axes_maybe_sp)

+        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
+
         if self.fuse_qkv_params:
             if is_qkvpack:
                 qkv_proj, ln_out = LayerNormDenseGeneral(

@@ -932,6 +993,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
                     bias_axes=(W_JOINED_AXES, W_TP_AXES),
+                    enable_low_rank_adaptation=lora_scope.qkv_proj,
+                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                     layernorm_input_axes=inputs_logical_axes_maybe_sp,
                     dot_input_axes=inputs_logical_axes_no_sp,
                     name='qkv',

@@ -954,6 +1018,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
                     bias_axes=(W_TP_AXES,),
+                    enable_low_rank_adaptation=lora_scope.qkv_proj,
+                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                     dtype=self.dtype,
                     kernel_init=query_init,
                     layernorm_input_axes=inputs_logical_axes_maybe_sp,

@@ -972,6 +1039,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
                     bias_axes=(W_JOINED_AXES, W_TP_AXES),
+                    enable_low_rank_adaptation=lora_scope.qkv_proj,
+                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                     name='kv',
                     dtype=self.dtype)(inputs_kv)
                 kv_proj = checkpoint_name(kv_proj, 'combined_kv_proj')

@@ -986,6 +1056,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
                     bias_axes=(W_TP_AXES,),
+                    enable_low_rank_adaptation=lora_scope.qkv_proj,
+                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                     dtype=self.dtype)
                 query, ln_out = LayerNormDenseGeneral(
                     enable_layernorm=self.input_layernorm,

@@ -1002,6 +1075,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
                     use_bias=self.use_bias,
                     bias_init=self.bias_init,
                     bias_axes=(W_TP_AXES,),
+                    enable_low_rank_adaptation=lora_scope.qkv_proj,
+                    low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+                    low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
                     dtype=self.dtype,
                     kernel_init=query_init,
                     layernorm_input_axes=inputs_logical_axes_maybe_sp,

@@ -1142,6 +1218,9 @@ class MultiHeadAttention(nn.Module):    # pylint: disable=too-few-public-methods
             use_bias=self.use_bias,
             bias_init=self.bias_init,
             bias_axes=(W_NO_SHARD_AXES,),
+            enable_low_rank_adaptation=lora_scope.output_proj,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             dtype=self.dtype,
             name='out')(x)
         out = checkpoint_name(out, 'out_proj')

@@ -1379,6 +1458,16 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
         Indicate the method to coupled the coordinates. It should be one of
         ['consecutive', 'alternate']. 'alternate' is to pair index :math:`i` with :math:`i + d/2`
         , d is the hidden dimension. 'consecutive' pairs index :math:`i` with :math:`i + 1`.
+    low_rank_adaptation_scope: str, default = 'none'
+        Indicate the scope to apply low rank adaptation. It should be one of
+        ['none', 'all', 'qkv_proj', 'output_proj', 'mlp', 'exclude_qkv_proj',
+        'exclude_output_proj', 'exclude_mlp']
+    low_rank_adaptation_dim: int, default = 32
+        The dimension for low rank adaptation, only used when
+        :attr:`enable_low_rank_adaptation=True`
+    low_rank_adaptation_alpha: float, default = None
+        The alpha for computing the scaling factor of LoRA output.
+        :math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
     enable_sequence_parallel: bool, default = False
         Whether to enable sequence parallelism to operations except dot.

@@ -1434,6 +1523,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     rotary_pos_emb_group_method: str = 'consecutive'
+    low_rank_adaptation_scope: str = 'none'
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     dtype: DType = jnp.float32
     drop_path: float = 0.0
     fuse_qkv_params: bool = True

@@ -1579,6 +1671,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             fuse_qkv_params=self.fuse_qkv_params,
             kernel_init=self.mha_kernel_init,
             use_bias=self.use_bias,

@@ -1646,6 +1741,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             float32_logits=self.float32_attention_logits,
             scale_attn_logits=self.scale_attn_logits,
             scaled_query_init=self.scaled_query_init,

@@ -1674,6 +1772,8 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
         mlp_input = with_sharding_constraint_by_logical_axes(
             mlp_input, (*generate_batch_seqlen_logical_axes(), HIDDEN_AXES))

+        lora_scope = _canonicalize_lora_scope(self.low_rank_adaptation_scope)
+
         # MlpBlock
         residual = mlp_input
         z, ln_out = LayerNormMLP(

@@ -1697,6 +1797,9 @@ class TransformerLayer(nn.Module):    # pylint: disable=too-few-public-methods
             bias_init=self.bias_init,
             bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
             bias_axes_2=(W_NO_SHARD_AXES,),
+            enable_low_rank_adaptation=lora_scope.mlp,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES),
             dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES),
             dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES),
transformer_engine/jax/praxis/module.py
@@ -131,6 +131,9 @@ class Linear(TransformerEngineBaseLayer):
     use_bias: bool = True
     bias_init: WeightInit = WeightInit.Constant(0.0)
     bias_axes: Tuple[str, ...] = ()
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     axis: Union[Iterable[int], int] = -1
     transpose_batch_sequence: bool = False
     sharding_type: ShardingType = ShardingType.SINGLE

@@ -147,6 +150,9 @@ class Linear(TransformerEngineBaseLayer):
             use_bias=self.use_bias,
             bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
             bias_axes=self.bias_axes,
+            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             axis=self.axis,
             dtype=self.dtype,
             transpose_batch_sequence=self.transpose_batch_sequence)

@@ -174,6 +180,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     use_bias: bool = False
     bias_init: WeightInit = WeightInit.Constant(0.0)
     bias_axes: Tuple[str, ...] = ()
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     return_layernorm_output: bool = True
     axis: Union[Iterable[int], int] = -1
     transpose_batch_sequence: bool = False

@@ -201,6 +210,9 @@ class LayerNormLinear(TransformerEngineBaseLayer):
             use_bias=self.use_bias,
             bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
             bias_axes=self.bias_axes,
+            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             return_layernorm_output=self.return_layernorm_output,
             axis=self.axis,
             dtype=self.dtype,

@@ -232,6 +244,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     bias_init: WeightInit = WeightInit.Constant(0.0)
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
+    enable_low_rank_adaptation: bool = False
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     return_layernorm_output: bool = True
     activations: Sequence[Union[str, Callable]] = ('relu',)
     intermediate_dropout_rate: float = 0.1

@@ -263,6 +278,9 @@ class LayerNormMLP(TransformerEngineBaseLayer):
             bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
             bias_axes_1=self.bias_axes_1,
             bias_axes_2=self.bias_axes_2,
+            enable_low_rank_adaptation=self.enable_low_rank_adaptation,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             return_layernorm_output=self.return_layernorm_output,
             activations=self.activations,
             intermediate_dropout_rate=self.intermediate_dropout_rate,
transformer_engine/jax/praxis/transformer.py
@@ -137,6 +137,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     rotary_pos_emb_group_method: str = 'consecutive'
+    low_rank_adaptation_scope: str = 'none'
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     fuse_qkv_params: bool = True
     transpose_batch_sequence: bool = True
     enable_sequence_parallel: bool = False

@@ -208,6 +211,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             fuse_qkv_params=self.fuse_qkv_params,
             transpose_batch_sequence=self.transpose_batch_sequence,
             enable_sequence_parallel=self.enable_sequence_parallel,

@@ -262,6 +268,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     enable_rotary_pos_emb: bool = False
     rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
     rotary_pos_emb_group_method: str = 'consecutive'
+    low_rank_adaptation_scope: str = 'none'
+    low_rank_adaptation_dim: int = 32
+    low_rank_adaptation_alpha: float = None
     enable_relative_embedding: bool = True
     relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
     drop_path: float = 0.0

@@ -332,6 +341,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
             enable_rotary_pos_emb=self.enable_rotary_pos_emb,
             rotary_pos_emb_windows=self.rotary_pos_emb_windows,
             rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+            low_rank_adaptation_scope=self.low_rank_adaptation_scope,
+            low_rank_adaptation_dim=self.low_rank_adaptation_dim,
+            low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
             enable_relative_embedding=self.enable_relative_embedding,
             relative_embedding=relative_embedding_flax_module,
             drop_path=self.drop_path,
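On the Praxis side the three new fields mirror the Flax attributes and are simply forwarded to the underlying Flax module, as the diff above shows. A minimal, untested configuration sketch (assuming TransformerLayer is exported from transformer_engine.jax.praxis and standard pax_fiddle usage; all values are illustrative):

    from praxis import pax_fiddle
    from transformer_engine.jax.praxis import TransformerLayer

    # Enable LoRA on the attention projections only, leaving the MLP frozen as-is.
    layer_tpl = pax_fiddle.Config(
        TransformerLayer,
        name='te_transformer_lora',
        low_rank_adaptation_scope='exclude_mlp',   # qkv_proj and output_proj get adapters
        low_rank_adaptation_dim=16,
        low_rank_adaptation_alpha=32.0,
    )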