OpenDAS / TransformerEngine / Commits

Commit a207db1d, authored Apr 01, 2025 by yuguo
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine
Parents: fbee8990, 69365f88

Changes: 101 files changed in this merge; the first 20 changed files are shown below, with 1659 additions and 2477 deletions (+1659, -2477).
docs/api/pytorch.rst (+1, -0)
examples/jax/encoder/common.py (+20, -0)
examples/jax/encoder/run_test_multiprocessing_encoder.sh (+7, -1)
examples/jax/encoder/test_model_parallel_encoder.py (+63, -22)
examples/jax/encoder/test_multigpu_encoder.py (+45, -15)
examples/jax/encoder/test_multiprocessing_encoder.py (+59, -31)
examples/jax/encoder/test_single_gpu_encoder.py (+31, -10)
examples/jax/mnist/test_single_gpu_mnist.py (+41, -10)
qa/L0_jax_unittest/test.sh (+3, -4)
qa/L0_pytorch_unittest/test.sh (+1, -1)
qa/L2_jax_unittest/test.sh (+23, -0)
tests/jax/distributed_test_base.py (+1, -1)
tests/jax/test_custom_call_compute.py (+1107, -747)
tests/jax/test_distributed_fused_attn.py (+4, -5)
tests/jax/test_distributed_layernorm.py (+65, -21)
tests/jax/test_distributed_layernorm_mlp.py (+102, -99)
tests/jax/test_distributed_softmax.py (+1, -1)
tests/jax/test_helper.py (+22, -22)
tests/jax/test_layer.py (+63, -51)
tests/jax/test_praxis_layers.py (+0, -1436)
docs/api/pytorch.rst

@@ -32,6 +32,7 @@ pyTorch
    :members: forward, set_context_parallel_group, set_tensor_parallel_group
 .. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
+  :members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
 .. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
   :members: reset, get_states, set_states, add, fork
examples/jax/encoder/common.py

@@ -4,7 +4,9 @@
 """Shared functions for the encoder tests"""
 from functools import lru_cache

+import transformer_engine
 from transformer_engine_jax import get_device_compute_capability
+from transformer_engine.common import recipe

 @lru_cache
@@ -19,3 +21,21 @@ def is_fp8_supported():
     """Return if FP8 has hardware supported"""
     gpu_arch = get_device_compute_capability(0)
     return gpu_arch >= 90
+
+
+@lru_cache
+def is_mxfp8_supported():
+    """Return if FP8 has hardware supported"""
+    gpu_arch = get_device_compute_capability(0)
+    return gpu_arch >= 100
+
+
+def get_fp8_recipe_from_name_string(name: str):
+    """Query recipe from a given name string"""
+    match name:
+        case "DelayedScaling":
+            return recipe.DelayedScaling()
+        case "MXFP8BlockScaling":
+            return recipe.MXFP8BlockScaling()
+        case _:
+            raise ValueError(f"Invalid fp8_recipe, got {name}")
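The two helpers added here are what the example scripts below use to turn a --fp8-recipe string into a recipe object before entering fp8_autocast. A minimal usage sketch, assuming the helpers above and the fp8_autocast keyword arguments exactly as they appear in this diff:

    # Sketch only: pick a recipe by name and hand it to fp8_autocast.
    # get_fp8_recipe_from_name_string / is_mxfp8_supported come from common.py above.
    import transformer_engine.jax as te
    from common import get_fp8_recipe_from_name_string, is_mxfp8_supported

    recipe_name = "MXFP8BlockScaling" if is_mxfp8_supported() else "DelayedScaling"
    fp8_recipe = get_fp8_recipe_from_name_string(recipe_name)

    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        # Layers built in this context pick up the selected FP8 recipe.
        ...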
examples/jax/encoder/run_test_multiprocessing_encoder.sh

@@ -12,6 +12,12 @@ wait
 for i in $(seq 0 $(($NUM_GPUS - 1)))
 do
-    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
+
+for i in $(seq 0 $(($NUM_GPUS - 1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
 done
 wait
examples/jax/encoder/test_model_parallel_encoder.py

@@ -19,10 +19,11 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
+from common import is_bf16_supported, get_fp8_recipe_from_name_string

 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
-from common import is_bf16_supported
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 DEVICE_DP_AXIS = "data"
 DEVICE_TP_AXIS = "model"

@@ -217,9 +218,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr

 def get_params_sharding(sharding_rules, abs_var_collect, mesh):

@@ -272,6 +272,19 @@ def train_and_evaluate(args):
         args.test_batch_size % num_gpu_dp == 0
     ), f"Test batch size needs to be multiple of {num_gpu_dp}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert (
+            args.batch_size / num_gpu_dp % 32 == 0
+        ), "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size / num_gpu_dp % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
+
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
         devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)

@@ -287,7 +300,9 @@ def train_and_evaluate(args):
         label_shape = [args.batch_size]
         with te.fp8_autocast(
-            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
+            enabled=args.use_fp8,
+            fp8_recipe=fp8_recipe,
+            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
         ):
             encoder = Net(num_embed, args.enable_sp)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)

@@ -371,21 +386,21 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for training (default: 64)",
+        help="input batch size for training (default: 128)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for testing (default: 64)",
+        help="input batch size for testing (default: 128)",
     )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=32,
+        default=64,
         metavar="N",
         help="maximum sequence length (default: 32)",
     )

@@ -416,6 +431,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
     )

@@ -426,7 +447,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

     @classmethod
     def setUpClass(cls):

@@ -437,29 +459,48 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        assert actual[0] < 0.50 and actual[1] > 0.76

-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.50 and actual[1] > 0.76
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.76

     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_sp(self):
+    def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        assert actual[0] < 0.50 and actual[1] > 0.76

-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8_sp(self):
-        """Test Transformer Engine with FP8 + SP"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8_with_sp(self):
+        """Test Transformer Engine with DelayedScaling FP8 + SP"""
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.76
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8_with_sp(self):
+        """Test Transformer Engine with MXFP8 + SP"""
         self.args.enable_sp = True
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.50 and actual[1] > 0.76

 if __name__ == "__main__":
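check_fp8 now stringifies the jaxpr of train_step and searches for the FP8 dtype names instead of the old "fp8_" marker. The same jaxpr-inspection trick can be shown with plain JAX, independent of Transformer Engine; this is only an illustrative sketch of the string search, not the encoder test itself:

    import jax
    import jax.numpy as jnp

    def cast_and_scale(x):
        # Round-trip through float8 so an f8 dtype shows up in the jaxpr text.
        return x.astype(jnp.float8_e4m3fn).astype(jnp.float32) * 2.0

    func_jaxpr = str(jax.make_jaxpr(cast_and_scale)(jnp.ones((4, 4), jnp.float32)))
    # jaxpr avals print abbreviated dtype names that start with f8_e4m3 / f8_e5m2.
    assert "f8_e4m3" in func_jaxpr or "f8_e5m2" in func_jaxpr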
examples/jax/encoder/test_multigpu_encoder.py

@@ -19,10 +19,11 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
+from common import is_bf16_supported, get_fp8_recipe_from_name_string

 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
-from common import is_bf16_supported
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 DEVICE_DP_AXIS = "data"
 PARAMS_KEY = "params"

@@ -198,9 +199,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr

 def get_params_sharding(sharding_rules, abs_var_collect, mesh):

@@ -243,6 +243,18 @@ def train_and_evaluate(args):
     num_gpu = jax.local_device_count()
     assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
     assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert (
+            args.batch_size / num_gpu % 32 == 0
+        ), "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size / num_gpu % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None

     device_mesh = mesh_utils.create_device_mesh((num_gpu,))
     with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:

@@ -257,7 +269,9 @@ def train_and_evaluate(args):
         label_shape = [args.batch_size]
         with te.fp8_autocast(
-            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
+            enabled=args.use_fp8,
+            fp8_recipe=fp8_recipe,
+            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
         ):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)

@@ -344,16 +358,16 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for training (default: 128)",
+        help="input batch size for training (default: 256)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for testing (default: 128)",
+        help="input batch size for testing (default: 256)",
     )
     parser.add_argument(
         "--max-seq-len",

@@ -389,6 +403,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )

     return parser.parse_args(args)

@@ -396,7 +416,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

     @classmethod
     def setUpClass(cls):

@@ -407,14 +428,23 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.535 and actual[1] > 0.73

-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.535 and actual[1] > 0.73
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73

 if __name__ == "__main__":
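The new MXFP8 guard in train_and_evaluate requires the per-device batch (the global batch split across jax.local_device_count() devices) to stay a multiple of 32. A standalone sketch of the intent of that check; the helper name below is ours, not part of the diff:

    def mxfp8_batch_ok(global_batch: int, num_devices: int, block: int = 32) -> bool:
        # Per-device batch must split evenly and then divide into 32-row MXFP8 blocks.
        return global_batch % num_devices == 0 and (global_batch // num_devices) % block == 0

    assert mxfp8_batch_ok(256, 8)          # 32 rows per device
    assert not mxfp8_batch_ok(128, 8)      # 16 rows per device, rejected for MXFP8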
examples/jax/encoder/test_multiprocessing_encoder.py

@@ -21,9 +21,15 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
-from common import is_bf16_supported, is_fp8_supported
+from common import (
+    is_bf16_supported,
+    is_fp8_supported,
+    is_mxfp8_supported,
+    get_fp8_recipe_from_name_string,
+)

 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

@@ -298,9 +304,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr

 def get_params_sharding(sharding_rules, abs_var_collect, mesh):

@@ -359,10 +364,16 @@ def train_and_evaluate(args):
     num_gpu_dp = 1
     num_gpu_tp = 1
-    assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
-    assert (
-        args.test_batch_size % num_gpu_dp == 0
-    ), f"Test batch size needs to be multiple of {num_gpu_dp}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert args.batch_size % 32 == 0, "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None

     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(

@@ -379,7 +390,9 @@ def train_and_evaluate(args):
         label_shape = [args.batch_size]
         with te.fp8_autocast(
-            args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
+            enabled=args.use_fp8,
+            fp8_recipe=fp8_recipe,
+            mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
         ):
             encoder = Net(num_embed)
             inputs = jnp.zeros(input_shape, dtype=jnp.int32)

@@ -482,23 +495,23 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for training (default: 64)",
+        help="input batch size for training (default: 128)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for testing (default: 64)",
+        help="input batch size for testing (default: 128)",
     )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=32,
+        default=64,
         metavar="N",
-        help="maximum sequence length (default: 32)",
+        help="maximum sequence length (default: 64)",
     )
     parser.add_argument(
         "--epochs",

@@ -527,6 +540,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--coordinator-address",
         type=str,

@@ -554,37 +573,46 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    gpu_has_fp8 = is_fp8_supported()
-    gpu_has_bf16 = is_bf16_supported()
-
-    def exec(self, use_fp8):
+    def exec(self, use_fp8, fp8_recipe):
         """Run 3 epochs for testing"""
         args = encoder_parser([])

         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
         dp_size = num_gpu // tp_size
-        batch_size = 64 // dp_size
+        assert args.batch_size % dp_size == 0, f"Batch size needs to be multiple of {dp_size}"
+        batch_size = args.batch_size // dp_size

         args.use_fp8 = use_fp8
         args.batch_size = batch_size
         args.test_batch_size = batch_size
         args.num_process = num_gpu
         args.process_id = self.process_id
+        args.fp8_recipe = fp8_recipe

         return train_and_evaluate(args)

-    @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
-        result = self.exec(False)
-        assert result[0] < 0.45 and result[1] > 0.79
+        result = self.exec(False, None)
+        assert result[0] < 0.505 and result[1] > 0.755

-    @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
-        result = self.exec(True)
-        assert result[0] < 0.455 and result[1] > 0.79
+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
+    )
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        result = self.exec(True, "DelayedScaling")
+        assert result[0] < 0.505 and result[1] > 0.755
+
+    @unittest.skipIf(not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8")
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling")
+        assert result[0] < 0.505 and result[1] > 0.754

 if __name__ == "__main__":
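exec() now derives the per-process batch from args.batch_size rather than a hard-coded 64: with N processes it uses 2-way tensor parallelism whenever N is even and greater than 1, dp_size = N // tp_size, and batch_size // dp_size per process. A small sketch of that split; the function name is ours, the logic mirrors the diff:

    def split_batch(num_process: int, global_batch: int):
        # Mirrors exec(): choose TP/DP sizes, then the per-process batch.
        tp_size = 2 if num_process > 1 and num_process % 2 == 0 else 1
        dp_size = num_process // tp_size
        assert global_batch % dp_size == 0, f"Batch size needs to be multiple of {dp_size}"
        return tp_size, dp_size, global_batch // dp_size

    assert split_batch(8, 128) == (2, 4, 32)    # 8 processes: 2-way TP, 4-way DP, 32 per process
    assert split_batch(1, 128) == (1, 1, 128)   # single process keeps the full batch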
examples/jax/encoder/test_single_gpu_encoder.py

@@ -16,10 +16,11 @@ from datasets import load_dataset
 from flax import linen as nn
 from flax.training import train_state
+from common import is_bf16_supported, get_fp8_recipe_from_name_string

 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
-from common import is_bf16_supported
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

 PARAMS_KEY = "params"
 DROPOUT_KEY = "dropout"

@@ -59,7 +60,7 @@ class Net(nn.Module):
         return x

-@partial(jax.jit)
+@jax.jit
 def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""

@@ -195,9 +196,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr

 def train_and_evaluate(args):

@@ -214,7 +214,12 @@ def train_and_evaluate(args):
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
     label_shape = [args.batch_size]

-    with te.fp8_autocast(enabled=args.use_fp8):
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
+
+    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)

@@ -309,6 +314,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )

     return parser.parse_args(args)

@@ -316,7 +327,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

     @classmethod
     def setUpClass(cls):

@@ -329,10 +341,19 @@ class TestEncoder(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.45 and actual[1] > 0.79

-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.455 and actual[1] > 0.79
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.455 and actual[1] > 0.79
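Across these test files the old te.fp8.is_fp8_available() probe is replaced by is_fp8_available(ScalingMode...) from transformer_engine.jax.quantize, queried once per scaling mode and returning a (supported, reason) pair. A hedged sketch of the resulting skip pattern, using only names that appear in this diff (the class and test names below are placeholders):

    import unittest
    from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

    # One probe per scaling mode; each returns (bool, reason string).
    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

    class TestRecipes(unittest.TestCase):
        @unittest.skipIf(not is_fp8_supported, fp8_reason)
        def test_delayed_scaling(self):
            ...

        @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
        def test_mxfp8(self):
            ...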
examples/jax/mnist/test_single_gpu_mnist.py

@@ -5,6 +5,8 @@
 import argparse
 import unittest
 from functools import partial
+import sys
+from pathlib import Path

 import jax
 import jax.numpy as jnp

@@ -16,6 +18,11 @@ from flax.training import train_state
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+
+DIR = str(Path(__file__).resolve().parents[1])
+sys.path.append(str(DIR))
+from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string

 IMAGE_H = 28
 IMAGE_W = 28

@@ -37,6 +44,7 @@ class Net(nn.Module):
         else:
             nn_Dense = nn.Dense
             # dtype is used for param init in TE but computation in Linen.nn
             dtype = jnp.float32 if self.use_te else jnp.bfloat16
         x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)

@@ -50,8 +58,8 @@ class Net(nn.Module):
         x = nn_Dense(features=128, dtype=dtype)(x)
         x = nn.relu(x)
         x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
-        x = nn_Dense(features=16, dtype=dtype)(x)
-        x = nn_Dense(features=10, dtype=dtype)(x)
+        x = nn_Dense(features=32, dtype=dtype)(x)
+        x = nn_Dense(features=32, dtype=dtype)(x)
         assert x.dtype == jnp.bfloat16
         return x

@@ -62,7 +70,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
     def loss_fn(var_collect, disable_dropout=False):
         logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs)
-        one_hot = jax.nn.one_hot(labels, 10)
+        one_hot = jax.nn.one_hot(labels, 32)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits

@@ -153,7 +161,7 @@ def get_datasets():
 def check_fp8(state, var_collect, input_shape, label_shape):
     "Check if model includes FP8."
-    assert "f8_" in str(
+    func_jaxpr = str(
         jax.make_jaxpr(apply_model)(
             state,
             jnp.empty(input_shape, dtype=jnp.bfloat16),

@@ -161,6 +169,7 @@ def check_fp8(state, var_collect, input_shape, label_shape):
             var_collect,
         )
     )
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr

 def train_and_evaluate(args):

@@ -179,7 +188,12 @@ def train_and_evaluate(args):
     input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
     label_shape = [args.batch_size]

-    with te.fp8_autocast(enabled=args.use_fp8):
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
+
+    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)

@@ -276,6 +290,12 @@ def mnist_parser(args):
             "It also enables Transformer Engine implicitly."
         ),
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--use-te", action="store_true", default=False, help="Use Transformer Engine"
     )

@@ -286,7 +306,8 @@ def mnist_parser(args):
 class TestMNIST(unittest.TestCase):
     """MNIST unittests"""
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

     @classmethod
     def setUpClass(cls):

@@ -298,13 +319,14 @@ class TestMNIST(unittest.TestCase):
         """Check If loss and accuracy match target"""
         desired_traing_loss = 0.055
         desired_traing_accuracy = 0.98
-        desired_test_loss = 0.04
+        desired_test_loss = 0.045
         desired_test_accuracy = 0.098
         assert actual[0] < desired_traing_loss
         assert actual[1] > desired_traing_accuracy
         assert actual[2] < desired_test_loss
         assert actual[3] > desired_test_accuracy

+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         self.args.use_te = True

@@ -312,10 +334,19 @@ class TestMNIST(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         self.verify(actual)

-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        self.verify(actual)
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
         self.verify(actual)
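The MNIST head widens from 10 to 32 output features and the labels are one-hot encoded to 32 classes to match; the extra columns simply never appear as targets. A tiny sketch of that padding with optax's softmax cross entropy (the same calls the example uses); the toy values below are ours:

    import jax
    import jax.numpy as jnp
    import optax

    labels = jnp.array([3, 7, 0])            # MNIST digits stay in 0-9
    logits = jnp.zeros((3, 32))              # the head now emits 32 features
    one_hot = jax.nn.one_hot(labels, 32)     # pad targets to the widened head
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    # Uniform logits over 32 classes give a loss of log(32), about 3.47.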
qa/L0_jax_unittest/test.sh

@@ -20,16 +20,15 @@ pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

 : ${TE_PATH:=/opt/transformerengine}

-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "test_praxis_layers.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"

 # Test without custom calls
-NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
+NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls"

 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
+pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"

 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"

-pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
qa/L0_pytorch_unittest/test.sh

@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
qa/L2_jax_unittest/test.sh (new file, mode 100644)

+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+set -xe
+
+pip install "nltk>=3.8.2"
+pip install pytest==8.2.1
+
+: ${TE_PATH:=/opt/transformerengine}
+
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
+
+# Test without custom calls
+NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
+
+pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
+pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
+
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
+
+# Make encoder tests to have run-to-run deterministic to have the stable CI results
+export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
tests/jax/distributed_test_base.py

@@ -82,7 +82,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
     'i32[1024]{0}',
     'bf16[1024,1024]{0}'
     """
-    match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
+    match = re.search(r"(i|f|u)(\d+).*\[([0-9,]*)\]", t)
     _, bits_of_type, shape = match.groups()
     bytes_of_type = int(bits_of_type) // 8
     if shape == "":
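The regex in assert_equal_collectives now also accepts unsigned-integer HLO element types (the added |u alternative). A standalone check of the pattern against a few HLO-style type strings, using the same expression as the diff:

    import re

    PATTERN = r"(i|f|u)(\d+).*\[([0-9,]*)\]"

    for t in ("i32[1024]{0}", "u8[64,64]{1,0}", "bf16[1024,1024]{0}"):
        kind, bits, shape = re.search(PATTERN, t).groups()
        bytes_of_type = int(bits) // 8
        print(t, "->", kind, bytes_of_type, "bytes/elem, shape", shape)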
tests/jax/test_custom_call_compute.py

@@ -2,31 +2,40 @@
 #
 # See LICENSE for license information.
+from contextlib import nullcontext
+from typing import Callable, List, Sequence, Union
+import os

 import jax
 import jax.numpy as jnp
+import numpy as np
 import pytest
 from jax import jit, value_and_grad
+from flax import linen as nn
+from functools import reduce
+import operator

-from utils import assert_allclose, assert_tree_like_allclose
-from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
-from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
-from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
-from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
-from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
-from transformer_engine.jax.cpp_extensions.transpose import (
-    _jax_transpose,
-    _jax_cast_transpose,
-    _jax_dbias_cast_transpose,
-)
-from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
+from utils import (
+    assert_allclose,
+    assert_tree_like_allclose,
+    pytest_parametrize_wrapper,
+)
+from transformer_engine.jax.layernorm import layernorm
+from transformer_engine.jax.layernorm_mlp import layernorm_mlp
+from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
+from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
+from transformer_engine.jax.cpp_extensions.quantization import (
+    _jax_quantize,
+    _jax_quantize_dbias,
+)
 from transformer_engine.jax import cpp_extensions as tex
+from transformer_engine.jax.quantize import (
+    DelayedScaleQuantizer,
+    ScaledTensor,
+    ScalingMode,
+    QuantizerFactory,
+    QuantizeAxis,
+)
+from transformer_engine.jax.quantize import helper
+from transformer_engine.jax.activation import activation
+from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.layernorm_dense import layernorm_dense
+from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x

 GEMM_CASES = [
     (256, 256, 512),
...
@@ -36,844 +45,1195 @@ GEMM_CASES = [
...
@@ -36,844 +45,1195 @@ GEMM_CASES = [
(
2048
,
1024
,
1024
),
(
2048
,
1024
,
1024
),
]
]
FP8_COMPUTE_TYPE
=
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
]
FP8_COMPUTE_TYPE
=
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
]
LN_CASES
=
[(
512
,
1024
)]
LN_CASES
=
[(
256
,
128
),
(
128
,
256
)]
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float32
]
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float32
]
is_fp8_supported
,
reason
=
is_fp8_available
()
is_fp8_supported
,
reason
=
helper
.
is_fp8_available
()
is_mxfp8_supported
,
reason
=
helper
.
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
class
TestFP8Dot
:
supported_scaling_modes
=
[]
""" Find supported scaling modes"""
@
staticmethod
if
is_fp8_supported
:
def
_generate_fp8_meta
():
supported_scaling_modes
.
append
(
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
)
fp8_dtype_list
=
[
FP8Helper
.
FWD_DTYPE
,
FP8Helper
.
FWD_DTYPE
,
FP8Helper
.
BWD_DTYPE
]
if
is_mxfp8_supported
:
amax_list
=
[
supported_scaling_modes
.
append
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
def
is_shape_supported_by_mxfp8
(
input_shape
):
]
try
:
scale_list
=
[
if
isinstance
(
input_shape
,
type
(
pytest
.
param
(
0
))):
jnp
.
ones
((
1
,),
jnp
.
float32
),
input_shape
=
input_shape
.
values
[
0
]
jnp
.
ones
((
1
,),
jnp
.
float32
),
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
get_scale_shape_2x
(
input_shape
)
jnp
.
ones
((
1
,),
jnp
.
float32
),
return
True
]
except
:
return
fp8_dtype_list
,
amax_list
,
scale_list
# get_scale_shapes will raise an exception if the shape is not supported
return
False
def
assert_bitwise_scaled_tensors
(
a
:
ScaledTensor
,
b
:
ScaledTensor
):
if
isinstance
(
a
,
ScaledTensor1x
)
and
isinstance
(
b
,
ScaledTensor1x
):
assert_allclose
(
a
.
data
,
b
.
data
)
assert_allclose
(
a
.
scale_inv
.
astype
(
jnp
.
uint8
),
b
.
scale_inv
.
astype
(
jnp
.
uint8
))
elif
isinstance
(
a
,
ScaledTensor2x
)
and
isinstance
(
b
,
ScaledTensor2x
):
assert_bitwise_scaled_tensors
(
a
.
rowwise_tensor
,
b
.
rowwise_tensor
)
assert_bitwise_scaled_tensors
(
a
.
colwise_tensor
,
b
.
colwise_tensor
)
else
:
pytest
.
fail
(
"Unsupported input types"
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_qdq
(
self
):
FP8_E4M3_MAX
=
(
jnp
.
finfo
(
jnp
.
float8_e4m3fn
).
max
).
astype
(
jnp
.
float32
)
x
=
jnp
.
asarray
([[
-
1
,
0.1
],
[
2
,
3
]],
jnp
.
float32
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
)).
reshape
(
1
)
scale
=
jnp
.
asarray
(
FP8_E4M3_MAX
/
amax
,
jnp
.
float32
).
reshape
(
1
)
scale_inv
=
(
1
/
scale
).
reshape
(
1
)
y
,
_
=
quantize
(
x
,
q_dtype
=
jnp
.
float8_e4m3fn
,
scale
=
scale
)
def
assert_dequantized_scaled_tensor
(
a
:
ScaledTensor
,
b
:
jnp
.
ndarray
):
z
=
dequantize
(
y
,
dq_dtype
=
jnp
.
float32
,
scale_inv
=
scale_inv
)
if
isinstance
(
a
,
ScaledTensor1x
):
if
a
.
layout
==
"T"
:
b_transpose
=
jnp
.
transpose
(
b
,
(
-
1
,
*
range
(
b
.
ndim
-
1
)))
assert_allclose
(
a
.
dequantize
(),
b_transpose
,
dtype
=
a
.
data
.
dtype
)
else
:
assert_allclose
(
a
.
dequantize
(),
b
,
dtype
=
a
.
data
.
dtype
)
elif
isinstance
(
a
,
ScaledTensor2x
):
assert_dequantized_scaled_tensor
(
a
.
get_rowwise_tensor
(),
b
)
assert_dequantized_scaled_tensor
(
a
.
get_colwise_tensor
(),
b
)
else
:
pytest
.
fail
(
"a must be a ScaledTensor object"
)
ALL_ACTIVATION_SHAPES
=
[(
32
,
64
),
(
16
,
128
,
256
)]
ALL_ACTIVATION_TYPES
=
[
(
"gelu"
,),
(
"gelu"
,
"linear"
),
(
"silu"
,),
(
"silu"
,
"linear"
),
(
"relu"
,),
(
"relu"
,
"linear"
),
(
"quick_gelu"
,),
(
"quick_gelu"
,
"linear"
),
(
"squared_relu"
,),
(
"squared_relu"
,
"linear"
),
]
ACTIVATION_TYPES
=
{
"L0"
:
[
(
"gelu"
,),
(
"gelu"
,
"linear"
),
],
"L2"
:
ALL_ACTIVATION_TYPES
,
}
assert_allclose
(
z
,
x
,
dtype
=
jnp
.
float8_e4m3fn
)
class
TestActivation
:
def
ref_act
(
self
,
x
,
activation_type
):
return
_jax_act_lu
(
x
,
activation_type
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
GEMM_CASES
)
def
value_n_grad_ref_func
(
self
,
x
,
activation_type
):
def
test_forward_bf16
(
self
,
m
,
n
,
k
):
jitted_reference
=
jit
(
value_and_grad
(
lambda
out
:
jnp
.
mean
(
self
.
ref_act
(
out
,
activation_type
)),
(
0
,))
)
return
jitted_reference
(
x
)
def
primitive_func
(
self
,
inputs
,
activation_type
,
quantizer
):
out
=
activation
(
inputs
,
activation_type
=
activation_type
,
quantizer
=
quantizer
)
return
jnp
.
mean
(
out
)
@
pytest_parametrize_wrapper
(
"shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
(
ALL_ACTIVATION_TYPES
# Test all activation types for this test to ensure all are functional, then just test a subset for the other tests to verify other functionality
),
)
def
test_act_grad
(
self
,
shape
,
activation_type
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
subkeys
=
jax
.
random
.
split
(
key
,
2
)
x
=
jax
.
random
.
uniform
(
key
,
shape
,
jnp
.
float32
)
a
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
m
,
k
),
jnp
.
bfloat16
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
b
=
jax
.
random
.
normal
(
subkeys
[
1
],
(
k
,
n
),
jnp
.
bfloat16
)
primitive_out
=
type_safe_dot_general
(
a
,
b
)
value_n_grad_primitive_func
=
jit
(
ref_out
=
jnp
.
dot
(
a
,
b
)
value_and_grad
(
self
.
primitive_func
,
(
0
,)),
static_argnums
=
(
1
,)
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
jnp
.
bfloat16
)
prim_out
,
(
prim_grad
,)
=
value_n_grad_primitive_func
(
x
,
activation_type
,
None
)
ref_out
,
(
ref_grad
,)
=
self
.
value_n_grad_ref_func
(
x
,
activation_type
)
assert_allclose
(
prim_out
,
ref_out
,
dtype
=
x
.
dtype
)
assert_allclose
(
prim_grad
,
ref_grad
,
dtype
=
x
.
dtype
)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-1)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
        )

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_axis=QuantizeAxis.ROWWISE,
        )

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_act_forward_with_delayed_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_axis
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-1)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
            q_axis=q_axis,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)
        assert_bitwise_scaled_tensors(te_output, jax_output)
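    # Illustrative, self-contained sketch (not part of the test suite): the "delayed scaling"
    # recipe derives the FP8 scale for the current step from amax values recorded on earlier
    # steps, which is why the quantizer object carries state and why several tests in this file
    # iterate a few times before comparing. This is only a toy model of the idea, not the TE
    # kernels; it assumes only public JAX APIs already imported by this module.
    def _sketch_delayed_scaling_step(self, x, amax_history, q_dtype=jnp.float8_e4m3fn):
        dtype_max = float(jnp.finfo(q_dtype).max)
        scale = dtype_max / jnp.max(amax_history)  # scale chosen from past amaxes only
        q = (x * scale).astype(q_dtype)
        # Record the current amax for use by future steps.
        new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
        return q, scale, new_history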
    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("shape", [(128, 128)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_axis
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-1)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
            q_dtype=output_type,
            q_axis=q_axis,
        )

        output = tex.act_lu(x, activation_type, quantizer)
        ref_out = self.ref_act(x, activation_type)

        assert_dequantized_scaled_tensor(output, ref_out)


NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        def compute_loss(x):
            # Higher precision to compute the loss
            x_ = x.astype(jnp.float32)
            return jnp.mean(jnp.square(x_)).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # if isinstance(ln_out, ScaledTensor):
            #     ln_out = ln_out.dequantize()
            return ln_out

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
        x = x.astype(inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
        else:
            beta = None

        jitted_reference = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    reference_func(
                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                    )
                ),
                (0, 1, 2),
            )
        )
        jitted_primitive = jit(
            value_and_grad(
                lambda x, gamma, beta: compute_loss(
                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
                ),
                (0, 1, 2),
            )
        )

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )
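    # Illustrative, self-contained sketch (not part of the test suite): the pure-JAX
    # normalization references used above reduce in float32 and cast back, along these
    # lines. Only public JAX APIs already imported by this module are used here;
    # _jax_layernorm/_jax_rmsnorm additionally handle zero_centered_gamma and optional
    # quantization.
    def _sketch_norm_references(self, x, gamma, beta, eps=1e-6):
        x32 = x.astype(jnp.float32)
        mean = jnp.mean(x32, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x32 - mean), axis=-1, keepdims=True)
        layernorm_out = ((x32 - mean) * jax.lax.rsqrt(var + eps) * gamma + beta).astype(x.dtype)

        mean2 = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
        rmsnorm_out = (x32 * jax.lax.rsqrt(mean2 + eps) * gamma).astype(x.dtype)
        return layernorm_out, rmsnorm_out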
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_norm_grad_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            q_dtype=out_dtype,
            q_axis=q_axis,
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
        )
    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_axis,
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)

        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
        x = jnp.asarray(x, inp_dtype)
        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis
        )
        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
            beta = jnp.asarray(beta, inp_dtype)
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        assert_bitwise_scaled_tensors(output, ref_out)
        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_norm_forward_with_delayed_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            q_axis=q_axis,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
            q_axis=QuantizeAxis.ROWWISE_COLWISE,
        )
QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

ALL_QUANTIZE_TEST_SHAPES = [
    (128, 128),
    (4, 256, 512),
]

QUANTIZE_TEST_SHAPES = {
    "L0": [
        (256, 128),
        (64, 16, 2, 256),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES,
}

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}


@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
        key = jax.random.PRNGKey(0)
        # Quantizer is created once as some quantization approaches use state from
        # previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_axis=q_axis,
        )

        n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)
            scaled_tensor = quantizer.quantize(x)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

    def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis
        )

        jax_output = _jax_quantize(input, quantizer=jax_quantizer)
        te_output = tex.quantize(input, quantizer=te_quantizer)
        assert_bitwise_scaled_tensors(jax_output, te_output)
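# Illustrative, self-contained sketch (not part of the test suite): test_qdq above checks a
# quantize -> dequantize round trip against the original tensor. For a per-tensor scaling
# scheme the round trip looks roughly like this, with the error bounded by FP8 resolution.
# Only public JAX APIs already imported by this module are assumed here.
def _sketch_quantize_dequantize_roundtrip(x, q_dtype=jnp.float8_e4m3fn):
    dtype_max = float(jnp.finfo(q_dtype).max)
    scale = dtype_max / jnp.max(jnp.abs(x))  # map the observed amax onto the dtype's max
    q = (x * scale).astype(q_dtype)          # quantize
    return q.astype(x.dtype) / scale         # dequantize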
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis):
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
        )

        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(input, quantizer=te_quantizer)
        )(input)
        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(input, quantizer=jax_quantizer)
        )(input)

        assert_bitwise_scaled_tensors(jax_output, te_output)
        assert_allclose(jax_dbias, te_dbias)
    def _test_quantize_dact_dbias(
        self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
        x = jnp.repeat(x, len(activation_type), axis=-1)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
        )
        is_casted_output = te_quantizer is not None

        te_output, te_dbias = jit(
            lambda dz, x: tex.quantize_dact_dbias(
                dz, x, activation_type=activation_type, is_dbias=is_dbias, quantizer=te_quantizer
            )
        )(dz, x)
        jax_output, jax_dbias = jit(
            lambda dz, x: _jax_quantize_dact_dbias(
                dz, x, activation_type=activation_type, is_dbias=is_dbias, quantizer=jax_quantizer
            )
        )(dz, x)

        if is_casted_output:
            assert_bitwise_scaled_tensors(jax_output, te_output)
        else:
            assert_allclose(jax_output, te_output)
        if is_dbias:
            assert_allclose(jax_dbias, te_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    def test_quantize_dact_dbias_no_quantization(
        self, in_dtype, input_shape, activation_type, is_dbias
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
            scaling_mode=ScalingMode.NVTE_NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_axis=QuantizeAxis.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_quantize_dact_dbias_delayed_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_axis=q_axis,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper(
        "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_quantize_dact_dbias_mxfp8_scaling(
        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
            # If it doesn't, move this check into the quantize_dact_dbias function and revert to JAX
            # implementation in the unsupported cases
            pytest.skip(
                f"Input shape {input_shape} is not supported by dact MXFP8 kernel in TE currently"
            )

        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
            q_axis=q_axis,
        )
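# Illustrative, self-contained sketch (not part of the test suite): the "dbias" produced by the
# fused quantize_dbias / quantize_dact_dbias paths is the upstream gradient summed over every
# axis except the last (feature) axis, i.e. the bias gradient of "y = x @ w + b". Only public
# JAX APIs already imported by this module are assumed here.
def _sketch_dbias_reference(dz):
    return jnp.sum(dz, axis=tuple(range(dz.ndim - 1)))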
class TestDense:
    def _ref_gemm_with_jnp_dot(self, a, b, layout):
        if layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

    def _generate_gemm_input(self, m, n, k, layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
            (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
            (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
        lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
        rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

        return (x, w, contracting_dims)

    @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
    @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_bf16(self, m, n, k, layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
        ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout):
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )

        primitive_out = tex.gemm(
            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
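    # Illustrative, self-contained sketch (not part of the test suite): the "TN"/"NT"/... layout
    # strings above only decide which axis of each operand is contracted. "N" contracts the last
    # axis of the LHS (first axis of the RHS), "T" the other one; jax.lax.dot_general expresses
    # this directly. Only public JAX APIs already imported by this module are assumed here.
    def _sketch_layout_to_contracting_dims(self, a, b, layout):
        lhs_contract = (1,) if layout[0] == "N" else (0,)
        rhs_contract = (0,) if layout[1] == "N" else (1,)
        return jax.lax.dot_general(a, b, ((lhs_contract, rhs_contract), ((), ())))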
    @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
    def test_dense_grad_bf16(self, m, n, k):
        layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

        def ref_func(x, w, layout):
            return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
        ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
        layout = "NN"
        x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

        def primitive_func(x, w, bias, contracting_dims, quantizer_set):
            primitive_out = dense(
                x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
            )
            return jnp.mean(primitive_out)

        def ref_func(x, w, bias, layout):
            return jnp.mean(
                self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

        n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )

        ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
            x, w, bias, layout
        )

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out


def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    """
    JAX native normalization reference
    - norm_type == "layernorm": layernorm
    - norm_type == "rmsnorm": rmsnorm
    """
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    if isinstance(ln_out, ScaledTensor):
        ln_out = ln_out.dequantize()
    return ln_out
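# Illustrative, self-contained sketch (not part of the test suite): with zero_centered_gamma the
# learned scale is stored as an offset around zero and applied as (gamma + 1), so a zero-initialized
# gamma still yields an identity scale. This is why the tests draw gamma from (-1, 1) in that mode
# and from (0, 2) otherwise. Only public JAX APIs already imported by this module are assumed here.
def _sketch_zero_centered_gamma(x_normed, gamma):
    return x_normed * (gamma + 1.0)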
class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(512, 128, 128)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)

        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
        else:
            beta = None

        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        def prim_func(x, w, gamma, beta):
            # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
            prim_out = layernorm_dense(
                x,
                w,
                gamma,
                beta,
                None,
                norm_type,
                zero_centered_gamma,
                eps,
                quantizer_set=quantizer_set,
            )
            return jnp.mean(prim_out)

        def ref_func(x, w, gamma, beta):
            x = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            return jnp.mean(jnp.dot(x, w))

        value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))

        ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
            x, w, gamma, beta
        )

        n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_w_grad,
                prim_gamma_grad,
                prim_beta_grad,
            ) = value_n_grad_prim_func(x, w, gamma, beta)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        if beta is not None:
            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(512, 128, 256)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    @pytest.mark.parametrize("use_bias", [True, False])
    def test_layernorm_mlp_grad(
        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
    ):
        """
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6

        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 6)

        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
            subkeys[1], (k, len(activation_type) * n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
            bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n,), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
            bias_2 = None

        quantizer_sets = QuantizerFactory.create_set(
            n_quantizer_sets=2,
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )

        def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(
                layernorm_mlp(
                    x,
                    gamma,
                    beta,
                    [kernel_1, kernel_2],
                    [bias_1, bias_2],
                    norm_type,
                    zero_centered_gamma=zero_centered_gamma,
                    epsilon=eps,
                    activation_type=activation_type,
                    quantizer_sets=quantizer_sets,
                )
            )

        def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            ln_out = _ref_jax_norm_impl(
                x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
            )
            # TODO: replace gemm with jnp.dot
            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
            if use_bias:
                bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                linear_1_out += jnp.reshape(bias_1, bias_1_shape)

            x = _jax_act_lu(linear_1_out, activation_type)
            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
            if use_bias:
                bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                linear_2_out += jnp.reshape(bias_2, bias_2_shape)

            return linear_2_out

        def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
            return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

        n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
                prim_gamma_grad,
                prim_kernel_1_grad,
                prim_kernel_2_grad,
                prim_bias_1_grad,
                prim_bias_2_grad,
            ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        ref_out, (
            ref_x_grad,
            ref_gamma_grad,
            ref_kernel_1_grad,
            ref_kernel_2_grad,
            ref_bias_1_grad,
            ref_bias_2_grad,
        ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

        assert_allclose(prim_out, ref_out, dtype=q_dtype)
        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
        if use_bias:
            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
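# Illustrative, self-contained sketch (not part of the test suite): the layernorm_mlp reference
# above is norm -> dense -> (gated) activation -> dense. A minimal plain-JAX version, assuming a
# ("gelu", "linear") activation pair and an RMSNorm front end, and using only public JAX APIs
# already imported by this module, looks like this.
def _sketch_mlp_reference(x, gamma, kernel_1, kernel_2, eps=1e-6):
    x32 = x.astype(jnp.float32)
    ln_out = (
        x32 * jax.lax.rsqrt(jnp.mean(jnp.square(x32), -1, keepdims=True) + eps) * gamma
    ).astype(x.dtype)
    h = jnp.dot(ln_out, kernel_1)        # (m, 2 * n) for a gated activation pair
    a, g = jnp.split(h, 2, axis=-1)
    h = jax.nn.gelu(a) * g
    return jnp.dot(h, kernel_2)          # back to the model width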
# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
    lhs_q = lhs_quantizer.quantize(
        lhs,
        is_rowwise=lhs_is_rowwise,
        is_colwise=not lhs_is_rowwise,
    )
    rhs_q = rhs_quantizer.quantize(
        rhs,
        is_rowwise=rhs_is_rowwise,
        is_colwise=not rhs_is_rowwise,
    )
    return lhs_q, rhs_q
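# Illustrative, self-contained sketch (not part of the test suite): float8_e4m3fn trades dynamic
# range for precision while float8_e5m2 does the opposite, which is why the forward and backward
# dtypes are paired explicitly in fwd_bwd_dtypes below; per the source comment, the E5M2 x E5M2
# pairing is not supported and is therefore not exercised. Only public JAX APIs are assumed here.
def _sketch_fp8_format_ranges():
    return float(jnp.finfo(jnp.float8_e4m3fn).max), float(jnp.finfo(jnp.float8_e5m2).max)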
# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]


@pytest_parametrize_wrapper(
    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
class TestGroupedDense:
    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
        ref_out_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            dim_nums = (contracting_dims, ((), ()))
            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
        return ref_out_list

    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, len(shape_list) * 2)

        lhs_list, rhs_list, contracting_dims_list = [], [], []
        for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
                (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
                (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
                dtype=dtype,
            )
            lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
            rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)

            lhs_list.append(lhs)
            rhs_list.append(rhs)
            contracting_dims_list.append(contracting_dims)

        return lhs_list, rhs_list, contracting_dims_list
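    # Illustrative, self-contained sketch (not part of the test suite): each group in a grouped
    # GEMM can carry its own layout, so the contracting dimensions differ per group. A tiny
    # concrete example, assuming only public JAX APIs already imported by this module:
    def _sketch_two_group_layouts(self):
        m, n, k = 8, 4, 16
        lhs_nn = jax.random.normal(jax.random.PRNGKey(0), (m, k))   # "N": contract the last axis
        lhs_tn = jax.random.normal(jax.random.PRNGKey(1), (k, m))   # "T": contract the first axis
        rhs = jax.random.normal(jax.random.PRNGKey(2), (k, n))
        out_nn = jax.lax.dot_general(lhs_nn, rhs, (((1,), (0,)), ((), ())))
        out_tn = jax.lax.dot_general(lhs_tn, rhs, (((0,), (0,)), ((), ())))
        return out_nn.shape, out_tn.shape  # both (m, n)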
@
pytest_parametrize_wrapper
(
"dtype"
,
[
jnp
.
bfloat16
,
jnp
.
float16
])
@
pytest_parametrize_wrapper
(
"layout_list"
,
[[
"NN"
,
"TN"
,
"NT"
,
"TT"
]])
def
test_grouped_gemm_fp16
(
self
,
dtype
,
shape_list
,
layout_list
):
lhs_list
,
rhs_list
,
contracting_dims_list
=
self
.
_generate_grouped_gemm_input
(
dtype
,
shape_list
,
layout_list
)
)
os
.
environ
[
"NVTE_JAX_WITH_FFI"
]
=
"1"
ref_out
=
self
.
_ref_grouped_gemm_with_jnp_dot
(
lhs_list
,
rhs_list
,
contracting_dims_list
)
ffi_output
=
tex
.
cast_transpose
(
primitive_out
=
tex
.
grouped_gemm
(
lhs_list
,
rhs_list
,
contracting_dims_list
)
input
,
amax
,
scale
,
scale_inv
,
out_dtype
,
static_axis_boundary
,
transpose_axis
for
i
in
range
(
len
(
shape_list
)):
assert_allclose
(
primitive_out
[
i
],
ref_out
[
i
],
dtype
=
dtype
)
-    @pytest.mark.parametrize(
-        "out_dtype",
-        [
-            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
-            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
-        ],
-    )
-    def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
-        amax = jnp.zeros(1, jnp.float32)
-        scale = jnp.ones(1, jnp.float32)
-        scale_inv = jnp.ones(1, jnp.float32)
-        key = jax.random.PRNGKey(0)
-        input = jax.random.uniform(key, input_shape, in_dtype)
-        static_axis_boundary = -1
-        jax_output = _jax_dbias_cast_transpose(
-            input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
-        )
-        os.environ["NVTE_JAX_WITH_FFI"] = "0"
-        noffi_output = tex.dbias_cast_transpose(
-            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
-        )
...
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
+    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
+    def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
+        fwd_dtype, bwd_dtype = fwd_bwd_dtype
+        quantizer_set = QuantizerFactory.create_set(
+            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
+        )
+        out_dtype = jnp.bfloat16
+        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
+        )
+        q_lhs_list = []
+        q_rhs_list = []
+        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
+            # quantizer_set.x and quantizer_set.kernel have the same q_dtype, we want to
+            # test the case where lhs and rhs have different q_dtypes
+            q_lhs, q_rhs = _quantize_gemm_pair(
+                lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
+            )
+            q_lhs_list.append(q_lhs)
+            q_rhs_list.append(q_rhs)
+        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
+        primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
+        allclose_dtype = jnp.float8_e4m3fn
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+            allclose_dtype = jnp.float8_e5m2
+        for i in range(len(shape_list)):
+            assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
-        os.environ["NVTE_JAX_WITH_FFI"] = "1"
-        ffi_output = tex.dbias_cast_transpose(
-            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
-        )
-        assert_tree_like_allclose(jax_output, ffi_output)
-        assert_tree_like_allclose(noffi_output, ffi_output)
...
+    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
+    def test_grouped_dense_grad_fp16(self, dtype, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            dtype, shape_list, layout_list
+        )
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
+            out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list
+        )
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
+        )
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)
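Both grouped-dense tests lean on differentiating one scalar loss with respect to several argument groups at once. As a quick, self-contained illustration of that pattern (the function and array names here are invented for the example, not taken from the test file):

    import jax
    import jax.numpy as jnp

    def loss(x, w, b):
        # scalar loss so value_and_grad returns one value plus one grad per argnum
        return jnp.sum(x @ w + b)

    value_and_grad_fn = jax.value_and_grad(loss, argnums=(0, 1, 2))
    x = jnp.ones((4, 3)); w = jnp.ones((3, 2)); b = jnp.zeros((2,))
    val, (dx, dw, db) = value_and_grad_fn(x, w, b)
    print(val, dx.shape, dw.shape, db.shape)  # 24.0 (4, 3) (3, 2) (2,)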
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
-@pytest.mark.parametrize(
-    "input_shape",
-    [
-        pytest.param((256, 128), id="(256, 128)"),
-        pytest.param((128, 512, 8), id="(128, 512, 8)"),
-    ],
-)
-@pytest.mark.parametrize(
-    "in_dtype",
-    [
-        pytest.param(jnp.float32, id="input_float32"),
-        pytest.param(jnp.float16, id="input_float16"),
-        pytest.param(jnp.bfloat16, id="input_bfloat16"),
-    ],
-)
-@pytest.mark.parametrize(
-    "out_dtype",
-    [
-        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
-        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
-    ],
-)
-def test_quantize(input_shape, in_dtype, out_dtype):
-    amax = jnp.zeros(1, jnp.float32)
-    scale = jnp.ones(1, jnp.float32)
-    scale_inv = jnp.ones(1, jnp.float32)
-    key = jax.random.PRNGKey(0)
-    input = jax.random.uniform(key, input_shape, in_dtype)
-    jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
-    os.environ["NVTE_JAX_WITH_FFI"] = "0"
-    noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
-    os.environ["NVTE_JAX_WITH_FFI"] = "1"
-    ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
-    assert_tree_like_allclose(jax_output, ffi_output)
-    assert_tree_like_allclose(noffi_output, ffi_output)
...
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
+    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+    def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
+        group_size = len(shape_list)
+        layout_list = ["NN" for _ in range(group_size)]
+        fwd_dtype, bwd_dtype = fwd_bwd_dtype
+        if fwd_dtype == jnp.float8_e5m2:
+            pytest.skip("We never use E5M2 for fwd_dtype in training")
+
+        # Question: should we use different quantizers for different groups?
+        ref_quantizer_set_list = []
+        quantizer_set_list = []
+        for _ in range(group_size):
+            ref_quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            ref_quantizer_set_list.append(ref_quantizer_set)
+            quantizer_set = QuantizerFactory.create_set(
+                scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
+            )
+            quantizer_set_list.append(quantizer_set)
+
+        out_dtype = jnp.bfloat16
+        x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
+            out_dtype, shape_list, layout_list
+        )
+        bias_list = []
+        key = jax.random.PRNGKey(1)
+        for shape in shape_list:
+            n = shape[1]
+            bias = jax.random.uniform(key, n, dtype=out_dtype)
+            bias_list.append(bias)
+
+        def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+            out_list = []
+            for i in range(len(x_list)):
+                out_list.append(
+                    dense(
+                        x_list[i],
+                        kernel_list[i],
+                        bias_list[i],
+                        contracting_dims=contracting_dims_list[i],
+                        quantizer_set=quantizer_set_list[i],
+                    )
+                )
+            # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
+            # and prevent them from being clamp to zero
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
+            out_list = grouped_dense(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
+            out_sum_list = [jnp.sum(out) for out in out_list]
+            return jnp.sum(jnp.asarray(out_sum_list))
+
+        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
+        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
+        ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
+            x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
+        )
+        primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
+            value_n_grad_primitive_func(
+                x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
+            )
+        )
+        allclose_dtype = jnp.float8_e4m3fn
+        if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
+            allclose_dtype = jnp.float8_e5m2
+        assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
+        for i in range(group_size):
+            assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
+            assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
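The `allclose_dtype` switch above loosens the tolerance whenever either direction uses E5M2, which carries fewer mantissa bits than E4M3. A tiny round-trip sketch (plain JAX, not TransformerEngine code; assumes a JAX version that exposes these float8 dtypes) makes the precision gap visible:

    import jax.numpy as jnp

    x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
    roundtrip_e4m3 = x.astype(jnp.float8_e4m3fn).astype(jnp.float32)
    roundtrip_e5m2 = x.astype(jnp.float8_e5m2).astype(jnp.float32)
    print(jnp.max(jnp.abs(x - roundtrip_e4m3)))  # smaller error: 3 mantissa bits
    print(jnp.max(jnp.abs(x - roundtrip_e5m2)))  # larger error: 2 mantissa bits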
tests/jax/test_distributed_fused_attn.py  View file @ a207db1d
...
@@ -6,7 +6,6 @@ import os
 import pytest
 import jax
 import jax.numpy as jnp
-import numpy as np
 from jax import random
 from distributed_test_base import (
     generate_configs,
...
@@ -104,7 +103,7 @@ class TestDistributedSelfAttn:
             hidden,
             None,  # no window
         ):
-            pytest.skip(f"No FusedAttn backend found")
+            pytest.skip("No FusedAttn backend found")
         col_ref = self.generate_collectives_count_ref(
             mesh_shape,
...
@@ -176,7 +175,7 @@ class TestDistributedCrossAttn:
             hidden,
             None,  # no window
         ):
-            pytest.skip(f"No FusedAttn backend found")
+            pytest.skip("No FusedAttn backend found")
         col_ref = self.generate_collectives_count_ref()
         runner = FusedAttnRunner(
...
@@ -256,7 +255,6 @@ class TestDistributedContextParallelSelfAttn:
         dropout_prob = 0.0
         is_training = True
         dp_size, cp_size, tp_size = mesh_shape
-        qkv_format = qkv_layout.get_qkv_format()
         batch, seqlen, num_head, hidden = data_shape
...
@@ -382,7 +380,7 @@ class TestDistributedContextParallelSelfAttn:
         if qkv_layout.is_thd() and not load_balanced:
             pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
-        return self.impl_test_context_parallel_attn(
+        self.impl_test_context_parallel_attn(
             device_count,
             mesh_shape,
             mesh_axes,
...
@@ -396,6 +394,7 @@ class TestDistributedContextParallelSelfAttn:
             CPStrategy.RING,
         )
         del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
+        return

 class TestReorderCausalLoadBalancing:
...
tests/jax/test_distributed_layernorm.py  View file @ a207db1d
...
@@ -13,11 +13,30 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
 from distributed_test_base import generate_configs, generate_collectives_count
 from distributed_test_base import compare_ops
+from utils import pytest_parametrize_wrapper
 from transformer_engine.jax import fp8_autocast
+from transformer_engine.common import recipe
 from transformer_engine.jax.layernorm import layernorm
+from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available

 DTYPES = [jnp.bfloat16, jnp.float32]
+NORM_INPUT_SHAPES = {
+    "L0": [[64, 64]],
+    "L2": [[64, 64]],
+}
+
+is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+SUPPORTED_RECIPES = []
+if is_fp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+if is_mxfp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

 class TestDistributedLayernorm:
...
@@ -41,25 +60,32 @@ class TestDistributedLayernorm:
         return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

-    def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
+    def generate_collectives_count_ref(
+        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
+    ):
         jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
         is_dp_enabled = mesh_resource.dp_resource is not None
         assert ln_type in ["layernorm", "rmsnorm"]
         all_reduce_loss_bytes = 4  # 1 * FP32
         # for loss, dgamma and dbeta
-        weight_count = 2 if ln_type == "layernorm" else 1
+        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
+        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
         allreduce_total_bytes = (
             all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
         )
+        other_bytes = 0
+        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
+            other_bytes = 384  # required for small scale shapes that require padding
         return generate_collectives_count(
-            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
+            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
-    @pytest.mark.parametrize("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
+    @pytest_parametrize_wrapper("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     def test_layernorm(
         self,
         device_count,
...
@@ -70,12 +96,19 @@ class TestDistributedLayernorm:
         dtype,
         zero_centered_gamma,
         shard_weights,
+        fp8_recipe,
     ):
         epsilon = 1e-6
         ln_type = "layernorm"
+        q_dtype = jnp.float8_e4m3fn

         def target_func(x, gamma, beta):
-            return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
+            quantizer = QuantizerFactory.create_set().x
+            return jnp.mean(
+                layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer)
+            )

         def ref_func(x, gamma, beta):
             x_ = jnp.asarray(x, jnp.float32)
...
@@ -92,11 +125,11 @@ class TestDistributedLayernorm:
             data_shape, mesh_resource, dtype, shard_weights
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype
+            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(mesh_resource=mesh_resource):
+        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
             beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
...
@@ -109,8 +142,8 @@ class TestDistributedLayernorm:
                 [x_, gamma_, beta_],
                 collective_count_ref,
                 grad_args=(0, 1, 2),
-                metric_fwd_dtype=dtype,
-                metric_bwd_dtype=dtype,
+                metric_fwd_dtype=q_dtype,
+                metric_bwd_dtype=q_dtype,
                 in_shardings=(x_pspec, g_pspec, b_pspec),
                 out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
             )
...
@@ -131,17 +164,28 @@ class TestDistributedLayernorm:
         )

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     def test_rmsnorm(
-        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
+        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights,
+        fp8_recipe,
     ):
         epsilon = 1e-6
         ln_type = "rmsnorm"
+        q_dtype = jnp.float8_e4m3fn

         def target_func(x, gamma):
-            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
+            quantizer = QuantizerFactory.create_set().x
+            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))

         def ref_func(x, gamma):
             x = jnp.asarray(x, jnp.float32)
...
@@ -154,11 +198,11 @@ class TestDistributedLayernorm:
             data_shape, mesh_resource, dtype, shard_weights
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype
+            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(mesh_resource=mesh_resource):
+        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
...
@@ -170,8 +214,8 @@ class TestDistributedLayernorm:
                 [x_, gamma_],
                 collective_count_ref,
                 grad_args=(0, 1),
-                metric_fwd_dtype=dtype,
-                metric_bwd_dtype=dtype,
+                metric_fwd_dtype=q_dtype,
+                metric_bwd_dtype=q_dtype,
                 in_shardings=(x_pspec, g_pspec),
                 out_shardings=(None, (x_pspec, g_pspec)),
             )
...
...
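Both norm test modules now build their FP8 parametrization from whatever the local GPU supports. A minimal, self-contained sketch of that guard pattern with plain pytest (the capability flags below are stand-ins for the TransformerEngine availability checks, not the real helpers):

    import pytest

    # stand-in capability flags; the real tests query is_fp8_available()
    HAS_FP8 = True
    HAS_MXFP8 = False

    RECIPES = []
    if HAS_FP8:
        RECIPES.append(pytest.param("DelayedScaling", id="DelayedScaling"))
    if HAS_MXFP8:
        RECIPES.append(pytest.param("MXFP8BlockScaling", id="MXFP8BlockScaling"))

    @pytest.mark.parametrize("fp8_recipe", RECIPES)
    def test_recipe_is_selected(fp8_recipe):
        assert fp8_recipe in ("DelayedScaling", "MXFP8BlockScaling")

Cases for unsupported recipes simply never appear in the parametrization, so the suite degrades gracefully on older hardware.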
tests/jax/test_distributed_layernorm_mlp.py  View file @ a207db1d
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+from typing import Callable, Sequence, Union, Optional
 import pytest
-from typing import Callable, List, Sequence, Union
 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from utils import (
+    assert_allclose,
+    assert_tree_like_allclose,
+    is_devices_enough,
+    pytest_parametrize_wrapper,
+)
-from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
-from transformer_engine.jax.fp8 import is_fp8_available
+from transformer_engine.common import recipe
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.flax import LayerNormMLP
-from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
+from transformer_engine.jax.layernorm_mlp import layernorm_mlp
 from transformer_engine.jax.sharding import (
     HIDDEN_AXES,
     HIDDEN_TP_AXES,
...
@@ -26,17 +32,25 @@ from transformer_engine.jax.sharding import (
     W_JOINED_AXES,
 )
 from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.quantize import QuantizerFactory
-from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough

 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+SUPPORTED_RECIPES = []
+if is_fp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+if is_mxfp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

 DTYPES = [jnp.bfloat16, jnp.float16]
-INPUT_SHAPE = [[64, 128, 32]]  # [batch, seqlen, hidden_in]
+INPUT_SHAPE = [[2, 64, 64]]  # [batch, seqlen, hidden_in]

 LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
 DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
 DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
-INTERMEDIATE = 16
+INTERMEDIATE = 64

 # Only test with FSDP and TP as DP is not used
...
@@ -66,13 +80,13 @@ class TestDistributedLayernormMLP:
         x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
         gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
         k1 = jax.random.normal(
-            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
+            subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
         ) / jnp.sqrt(hidden_in)
         k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
             INTERMEDIATE
         )
         if use_bias:
-            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
+            b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
             b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
         else:
             b1 = None
...
@@ -86,35 +100,13 @@ class TestDistributedLayernormMLP:
         ln_scale: jnp.ndarray,
         kernel_1: jnp.ndarray,
         kernel_2: jnp.ndarray,
-        bias_1: jnp.ndarray,
-        bias_2: jnp.ndarray,
-        amax_list_1: List[jnp.ndarray],
-        amax_list_2: List[jnp.ndarray],
-        scale_list_1: List[jnp.ndarray],
-        scale_list_2: List[jnp.ndarray],
+        bias_1: Optional[jnp.ndarray],
+        bias_2: Optional[jnp.ndarray],
         layernorm_type: str = "rmsnorm",
         activation_type: Sequence[Union[str, Callable]] = ("gelu",),
-        use_bias: bool = True,
         multi_gpus: bool = False,
     ) -> jnp.ndarray:
-        fp8_meta_pkg1 = FP8MetaPackage(
-            amax_list_1[0], scale_list_1[0], amax_list_1[1], scale_list_1[1],
-            amax_list_1[2], scale_list_1[2],
-        )
-        fp8_meta_pkg2 = FP8MetaPackage(
-            amax_list_2[0], scale_list_2[0], amax_list_2[1], scale_list_2[1],
-            amax_list_2[2], scale_list_2[2],
-        )
         if multi_gpus:
             layernorm_input_axes = LAYERNORM_INPUT_AXES
             dot_1_input_axes = DOT_1_INPUT_AXES
...
@@ -124,83 +116,64 @@ class TestDistributedLayernormMLP:
             dot_1_input_axes = None
             dot_2_input_axes = None

+        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
+
         # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
         return jnp.mean(
-            fused_layernorm_fp8_mlp(
+            layernorm_mlp(
                 x,
                 ln_scale,
+                None,
                 [kernel_1, kernel_2],
                 [bias_1, bias_2],
-                [fp8_meta_pkg1, fp8_meta_pkg2],
                 layernorm_type,
-                layernorm_input_axes=layernorm_input_axes,
+                norm_input_axes=layernorm_input_axes,
                 dot_1_input_axes=dot_1_input_axes,
                 dot_2_input_axes=dot_2_input_axes,
                 activation_type=activation_type,
-                use_bias=use_bias,
+                quantizer_sets=quantizer_sets,
             )
         )

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     def test_layernorm_fp8_mlp_primitive(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
     ):
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
         layernorm_type = "rmsnorm"
-        fp8_amax_list_1 = [
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-        ]
-        fp8_amax_list_2 = [
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-        ]
-        fp8_scale_list_1 = [
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-        ]
-        fp8_scale_list_2 = [
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-        ]
         inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
             input_shape, activation_type, use_bias, dtype
         )
-        inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
-        static_inputs = [layernorm_type, activation_type]
+        static_inputs = [layernorm_type, activation_type, use_bias]
         value_and_grad_func = jax.value_and_grad(
             self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
         )

         # Single GPU
-        single_jitter = jax.jit(
-            value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
-        )
-        with fp8_autocast(enabled=True):
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            single_jitter = jax.jit(
+                value_and_grad_func,
+                static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
+            )
             single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

         # Multi GPUs
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
-            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
+        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
+            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
             k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
             k1_ = jax.device_put(k1, k1_sharding)
             k2_ = jax.device_put(k2, k2_sharding)
             if use_bias:
-                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
+                b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
                 b1_ = jax.device_put(b1, b1_sharding)
             else:
                 b1_sharding = b1_ = None
...
@@ -208,7 +181,7 @@ class TestDistributedLayernormMLP:
             # Position ref for sharding pspec lists
             #   x, gamma, k1, k2, b1,
-            #   b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
+            #   b2
             in_shardings = (
                 None,
                 None,
...
@@ -216,14 +189,10 @@ class TestDistributedLayernormMLP:
                 k2_sharding,
                 b1_sharding,
                 None,
-                None,
-                None,
-                None,
-                None,
             )
             out_shardings = (
                 None,
-                (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
+                (None, None, k1_sharding, k2_sharding, b1_sharding, None),
             )

             multi_jitter = jax.jit(
...
@@ -245,15 +214,42 @@ class TestDistributedLayernormMLP:
                     m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                 )
             else:
+                is_gated = len(activation_type) > 1
+                rtol = None
+                atol = None
+                if is_gated:
+                    if dtype == jnp.bfloat16:
+                        if i == 2:
+                            rtol = 800
+                            atol = 9e-2
+                        if i == 4:
+                            atol = 300
+                            rtol = 1e-1
+                    if dtype == jnp.float16:
+                        if i == 1:  # gamma
+                            rtol = 200
+                            atol = 1e-2
+                        if i == 2:
+                            rtol = 2000
+                            atol = 7e-2
+                        if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
+                            # Accumulating dbias across a large tensor introduces a larger difference
+                            rtol = 200
+                            atol = 4e-2
+                        if i == 4 and fp8_recipe == recipe.DelayedScaling():
+                            rtol = 2200
+                            atol = 9e-2
                 assert_allclose(
                     multi_grads[i],
                     single_grads[i],
                     dtype=dtype,
+                    rtol=rtol,
+                    atol=atol,
                     err_msg=f"multi_grads[{i}] is not close",
                 )

     def _test_layernorm_mlp(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
     ):
         batch, seqlen, hidden_in = input_shape
         layernorm_type = "rmsnorm"
...
@@ -265,7 +261,7 @@ class TestDistributedLayernormMLP:
         init_rngs = {"params": subkeys[1]}

         # Single GPUs
-        with fp8_autocast(enabled=use_fp8):
+        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
             ln_mlp_single = LayerNormMLP(
                 layernorm_type=layernorm_type,
                 transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
...
@@ -282,7 +278,9 @@ class TestDistributedLayernormMLP:
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
+        with mesh, fp8_autocast(
+            enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ):
             ln_mlp_sharded = LayerNormMLP(
                 layernorm_type=layernorm_type,
                 transpose_batch_sequence=False,
...
@@ -310,25 +308,30 @@ class TestDistributedLayernormMLP:
         assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
         assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
     def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
         self._test_layernorm_mlp(
             mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
         )

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
-    @pytest.mark.parametrize("use_bias", [True, False])
-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("dtype", DTYPES)
-    def test_layernorm_fp8_mlp_layer(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
-    ):
-        self._test_layernorm_mlp(
-            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
-        )
+    # TODO: debug
+    # @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    # @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    # @pytest_parametrize_wrapper(
+    #     "activation_type", [("gelu",), ("gelu", "linear")]
+    # )
+    # @pytest_parametrize_wrapper("use_bias", [True, False])
+    # @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    # @pytest_parametrize_wrapper("dtype", DTYPES)
+    # @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    # def test_layernorm_fp8_mlp_layer(
+    #     self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+    # ):
+    #     self._test_layernorm_mlp(
+    #         mesh_config, activation_type, use_bias, input_shape, dtype,
+    #         use_fp8=True, fp8_recipe=fp8_recipe
+    #     )
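The recurring flow in these distributed tests is: build a device mesh, enter `fp8_autocast` with the recipe under test, shard the inputs with `NamedSharding`, and only then jit and run the function. A stripped-down sketch of that flow, assuming a host with at least two GPUs and using a stand-in compute function rather than the real `layernorm_mlp`:

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec
    from transformer_engine.common import recipe
    from transformer_engine.jax import fp8_autocast

    def f(x, w):
        return jnp.mean(x @ w)  # stand-in for the fused layernorm-MLP call

    devices = np.asarray(jax.devices()[:2]).reshape(2, 1)
    mesh = Mesh(devices, ("fsdp", "tp"))
    with mesh, fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling()):
        x = jax.device_put(jnp.ones((8, 16)), NamedSharding(mesh, PartitionSpec("fsdp", None)))
        w = jax.device_put(jnp.ones((16, 4)), NamedSharding(mesh, PartitionSpec(None, "tp")))
        out = jax.jit(f)(x, w)
    print(out)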
tests/jax/test_distributed_softmax.py  View file @ a207db1d
...
@@ -3,8 +3,8 @@
 # See LICENSE for license information.
 import warnings
-import pytest
 from functools import partial
+import pytest
 import jax
 import jax.numpy as jnp
...
tests/jax/test_helper.py  View file @ a207db1d
...
@@ -13,13 +13,13 @@ from utils import assert_allclose
 from transformer_engine.common.recipe import DelayedScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

 is_fp8_supported, reason = is_fp8_available()

-class TestFP8Helper(unittest.TestCase):
+class TestQuantizeConfig(unittest.TestCase):

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_initialize(self):
...
@@ -27,30 +27,30 @@ class TestFP8Helper(unittest.TestCase):
         fp8_format = FP8Format.E4M3
         amax_history_len = 10
-        FP8Helper.initialize(
+        QuantizeConfig.initialize(
             margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
         )
         self.assertEqual(
-            FP8Helper.MARGIN,
+            QuantizeConfig.MARGIN,
             margin,
-            f"FP8Helper.MARGIN initialization failed, should be {margin}"
-            f" but got {FP8Helper.MARGIN}.",
+            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
+            f" but got {QuantizeConfig.MARGIN}.",
         )
         self.assertEqual(
-            FP8Helper.FP8_FORMAT,
+            QuantizeConfig.FP8_FORMAT,
             fp8_format,
-            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {FP8Helper.FP8_FORMAT}.",
+            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
+            f" but got {QuantizeConfig.FP8_FORMAT}.",
         )
         self.assertEqual(
-            FP8Helper.AMAX_HISTORY_LEN,
+            QuantizeConfig.AMAX_HISTORY_LEN,
             amax_history_len,
-            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
+            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
+            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
         )
-        FP8Helper.finalize()
+        QuantizeConfig.finalize()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):
...
@@ -61,12 +61,12 @@ class TestFP8Helper(unittest.TestCase):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
+        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
+        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)
...
@@ -74,7 +74,7 @@ class TestFP8Helper(unittest.TestCase):
 class TestFP8Functions(unittest.TestCase):

     def _check_defult_state(self):
-        self.assertFalse(FP8Helper.is_fp8_enabled())
+        self.assertFalse(QuantizeConfig.is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)
...
@@ -84,32 +84,32 @@ class TestFP8Functions(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast(self):
-        FP8Helper.finalize()  # Ensure the testing not affect by previous tests.
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
-            self.assertFalse(FP8Helper.is_fp8_enabled())
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
-        FP8Helper.finalize()  # Ensure the testing not affect by previous tests.
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
...
@@ -126,7 +126,7 @@ class TestFP8Functions(unittest.TestCase):
         with jax.sharding.Mesh(devices, ("dp", "tp")):
             for sr in mesh_s:
                 with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
-                    self.assertTrue(FP8Helper.is_fp8_enabled())
+                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                     self._compare_delay_scaling(get_delayed_scaling(), ds)
                     self.assertEqual(sr, global_mesh_resource())
...
...
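These tests guard against leaked global state by finalizing the module-level config before asserting on its defaults. A tiny stand-alone illustration of that hygiene pattern, using a made-up config object rather than the real QuantizeConfig:

    import unittest

    class _Config:
        # stand-in for a module-level configuration singleton
        ENABLED = False

        @classmethod
        def initialize(cls):
            cls.ENABLED = True

        @classmethod
        def finalize(cls):
            cls.ENABLED = False

    class TestConfigHygiene(unittest.TestCase):
        def test_default_state(self):
            _Config.finalize()  # make sure earlier tests cannot leak state in
            self.assertFalse(_Config.ENABLED)
            _Config.initialize()
            self.assertTrue(_Config.ENABLED)
            _Config.finalize()

    if __name__ == "__main__":
        unittest.main()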
tests/jax/test_layer.py
View file @
a207db1d
...
@@ -20,11 +20,14 @@ from utils import (
...
@@ -20,11 +20,14 @@ from utils import (
from
utils
import
DecoderLayer
as
RefDecoderLayer
from
utils
import
DecoderLayer
as
RefDecoderLayer
from
utils
import
EncoderLayer
as
RefEncoderLayer
from
utils
import
EncoderLayer
as
RefEncoderLayer
from
transformer_engine.common
.recipe
import
Format
from
transformer_engine.common
import
recipe
from
transformer_engine.jax.flax
import
TransformerLayer
,
TransformerLayerType
from
transformer_engine.jax.flax
import
TransformerLayer
,
TransformerLayerType
from
transformer_engine.jax.fp8
import
FP8Helper
,
is_fp8_available
from
transformer_engine.jax.quantize
import
(
QuantizeConfig
,
is_fp8_supported
,
reason
=
is_fp8_available
()
ScalingMode
,
is_fp8_available
,
update_collections
,
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"function"
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"function"
)
...
@@ -35,12 +38,21 @@ def enable_fused_attn():
...
@@ -35,12 +38,21 @@ def enable_fused_attn():
del
os
.
environ
[
"NVTE_FUSED_ATTN"
]
del
os
.
environ
[
"NVTE_FUSED_ATTN"
]
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
QUANTIZE_RECIPES
=
[]
""" Find supported scaling modes"""
if
is_fp8_supported
:
QUANTIZE_RECIPES
.
append
(
pytest
.
param
(
recipe
.
DelayedScaling
(),
id
=
"DelayedScaling"
))
if
is_mxfp8_supported
:
QUANTIZE_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
DATA_SHAPE
=
[
# (batch, seqlen, emb_dim)
DATA_SHAPE
=
[
# (batch, seqlen, emb_dim)
pytest
.
param
((
32
,
128
,
1024
),
id
=
"32-128-1024"
),
pytest
.
param
((
32
,
128
,
1024
),
id
=
"32-128-1024"
),
pytest
.
param
((
32
,
512
,
1024
),
id
=
"32-512-1024"
),
]
]
DTYPE
=
[
jnp
.
float32
,
jnp
.
bfloat16
]
DTYPE
=
[
jnp
.
bfloat16
]
FP8_FORMATS
=
[
Format
.
E4M3
,
Format
.
HYBRID
]
_KEY_OF_RESIDUAL_POST_LAYERNORM
=
"apply_residual_connection_post_layernorm"
_KEY_OF_RESIDUAL_POST_LAYERNORM
=
"apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM
=
"output_layernorm"
_KEY_OF_OUTPUT_LAYERNORM
=
"output_layernorm"
...
@@ -80,27 +92,37 @@ BASE_ATTRS = {
...
@@ -80,27 +92,37 @@ BASE_ATTRS = {
}
}
ATTRS
=
[
ATTRS
=
[
# attrs0
{},
{},
# attrs1
{
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
},
},
# attrs2
{
{
_KEY_OF_ZERO_CENTERED_GAMMA
:
True
,
_KEY_OF_ZERO_CENTERED_GAMMA
:
True
,
_KEY_OF_LAYERNORM_EPS
:
1e-2
,
_KEY_OF_LAYERNORM_EPS
:
1e-2
,
},
},
# attrs3
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
},
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
},
# attrs4
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
},
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
},
# attrs5
{
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
,
},
},
# attrs6
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_DROP_PATH
:
0.1
},
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_DROP_PATH
:
0.1
},
# attrs7
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_FUSE_QKV_PARAMS
:
False
},
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_FUSE_QKV_PARAMS
:
False
},
# attrs8
{
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
},
},
# attrs9
{
{
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
@@ -109,12 +131,14 @@ ATTRS = [
...
@@ -109,12 +131,14 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
},
},
# attrs10
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
},
},
# attrs11
{
{
_KEY_OF_NUM_HEADS
:
8
,
_KEY_OF_NUM_HEADS
:
8
,
_KEY_OF_NUM_GQA_GROUPS
:
4
,
_KEY_OF_NUM_GQA_GROUPS
:
4
,
...
@@ -123,33 +147,7 @@ ATTRS = [
...
@@ -123,33 +147,7 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,),
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,),
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
},
},
{
# attrs12
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
},
{
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_HIDDEN_DROPOUT
:
0.8
,
_KEY_OF_INTERMEDIATE_DROPOUT
:
0.5
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
_KEY_OF_USE_BIAS
:
True
,
},
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
},
{
_KEY_OF_NUM_HEADS
:
8
,
_KEY_OF_NUM_GQA_GROUPS
:
4
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,)),
_KEY_OF_USE_BIAS
:
True
,
},
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
@@ -158,12 +156,14 @@ ATTRS = [
...
@@ -158,12 +156,14 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
},
# attrs13
{
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_ENABLE_ROPE
:
True
,
_KEY_OF_ENABLE_ROPE
:
True
,
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
},
},
# attrs14
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
...
@@ -173,6 +173,7 @@ ATTRS = [
...
@@ -173,6 +173,7 @@ ATTRS = [
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
},
# attrs15
{
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
@@ -180,26 +181,32 @@ ATTRS = [
...
@@ -180,26 +181,32 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD
:
"alternate"
,
_KEY_OF_ROPE_GROUP_METHOD
:
"alternate"
,
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
},
},
# attrs16
{
{
_KEY_OF_HIDDEN_DROPOUT
:
0.3
,
_KEY_OF_HIDDEN_DROPOUT
:
0.3
,
_KEY_OF_HIDDEN_DROPOUT_DIMS
:
(
0
,),
_KEY_OF_HIDDEN_DROPOUT_DIMS
:
(
0
,),
_KEY_OF_INTERMEDIATE_DROPOUT
:
0.5
,
_KEY_OF_INTERMEDIATE_DROPOUT
:
0.5
,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS
:
(
1
,),
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS
:
(
1
,),
},
},
# attrs17
{
{
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_USE_BIAS
:
True
,
},
},
# attrs18
{
{
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"no_bias"
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"no_bias"
,
},
},
# attrs19
{
{
_KEY_OF_ATTENTION_DROPOUT
:
0.3
,
_KEY_OF_ATTENTION_DROPOUT
:
0.3
,
},
},
# attrs20
{
{
_KEY_OF_MLP_ACTIVATIONS
:
((
"relu"
,
"relu"
)),
_KEY_OF_MLP_ACTIVATIONS
:
((
"relu"
,
"relu"
)),
},
},
# attrs21
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
...
@@ -207,6 +214,7 @@ ATTRS = [
...
@@ -207,6 +214,7 @@ ATTRS = [
_KEY_OF_WINDOW_SIZE
:
(
64
,
0
),
# Left size must < DATA_SHAPE seqlen
_KEY_OF_WINDOW_SIZE
:
(
64
,
0
),
# Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
},
# attrs22
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
...
@@ -296,20 +304,24 @@ class BaseRunner:
         ref_params, test_params = self._sync_params(ref_params, test_params)

-        if FP8Helper.is_fp8_enabled():
+        if QuantizeConfig.is_fp8_enabled():
             for _ in range(4):
-                _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
+                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,
                     test_masks,
                     test_params,
                     test_others,
                     test_layer,
                 )
-                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
-                test_others = FP8Helper.update_collections(
-                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
-                )
-                del tmp_grad, fp8_meta_grad
+                if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+                    _, updated_quantize_meta = flax.core.pop(
+                        updated_state[0], QuantizeConfig.COLLECTION_NAME
+                    )
+                    test_others = update_collections(
+                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
+                    )
+                    del updated_quantize_meta
+                del updated_state

         grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
@@ -436,29 +448,29 @@ class BaseTester:
...
@@ -436,29 +448,29 @@ class BaseTester:
def
test_forward
(
self
,
data_shape
,
dtype
,
attrs
):
def
test_forward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype forward"""
"""Test normal datatype forward"""
FP8Helper
.
finalize
()
# Ensure FP8 disabled.
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
)
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
)
def
test_backward
(
self
,
data_shape
,
dtype
,
attrs
):
def
test_backward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype backward"""
"""Test normal datatype backward"""
FP8Helper
.
finalize
()
# Ensure FP8 disabled.
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
)
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_
format"
,
FP8_FORMAT
S
)
@
pytest
.
mark
.
parametrize
(
"fp8_
recipe"
,
QUANTIZE_RECIPE
S
)
def
test_forward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
format
):
def
test_forward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
recipe
):
"""Test forward with fp8 enabled"""
"""Test forward with fp8 enabled"""
FP8Helper
.
initialize
(
fp8_
format
=
fp8_format
)
QuantizeConfig
.
initialize
(
fp8_
recipe
=
fp8_recipe
)
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
FP8Helper
.
finalize
()
QuantizeConfig
.
finalize
()
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_
format"
,
FP8_FORMAT
S
)
@
pytest
.
mark
.
parametrize
(
"fp8_
recipe"
,
QUANTIZE_RECIPE
S
)
def
test_backward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
format
):
def
test_backward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
recipe
):
"""Test backward with fp8 enabled"""
"""Test backward with fp8 enabled"""
FP8Helper
.
initialize
(
fp8_
format
=
fp8_format
)
QuantizeConfig
.
initialize
(
fp8_
recipe
=
fp8_recipe
)
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
FP8Helper
.
finalize
()
QuantizeConfig
.
finalize
()
class
TestEncoderLayer
(
BaseTester
):
class
TestEncoderLayer
(
BaseTester
):
...
...
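
Note: the BaseRunner/BaseTester hunks above move the layer tests from the FP8Helper/fp8_format flow to a recipe-driven QuantizeConfig flow. A minimal sketch of the new pattern follows; the import path and the recipe list are assumptions for illustration, while QuantizeConfig.initialize/finalize and the "fp8_recipe"/QUANTIZE_RECIPES names come from the diff itself.

import pytest
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import QuantizeConfig  # import path assumed, not shown in this diff

# Stand-in for the recipe list defined in the real test module.
QUANTIZE_RECIPES = [recipe.DelayedScaling()]


def run_case():
    """Placeholder for self.runner(attrs).test_forward(...) in the real tests."""


@pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
def test_forward_with_fp8(fp8_recipe):
    QuantizeConfig.initialize(fp8_recipe=fp8_recipe)  # enable FP8 quantization for this recipe
    try:
        run_case()
    finally:
        QuantizeConfig.finalize()  # restore the non-FP8 default so later tests are unaffected

The try/finally is only a defensive variant; the tests in the diff call finalize() directly after the runner.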
tests/jax/test_praxis_layers.py deleted 100644 → 0  View file @ fbee8990
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
from functools import partial
from typing import Dict, Tuple

import flax
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest

from utils import assert_allclose

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType

is_fp8_supported, reason = is_fp8_available()

DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]  # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    for key in ref_fd:
        assert key in test_fd, f"{key} not found in test dict {test_fd}"
        assert isinstance(
            test_fd[key], type(ref_fd[key])
        ), f"The data type is not match between ref and test Dict on {key=}"
        if isinstance(ref_fd[key], Dict):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(
                ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
            )
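
For orientation, compare_dict recurses through nested dicts such as the Praxis/Flax weight-gradient trees and checks every leaf with assert_allclose. A toy call (data invented for illustration, relying on the definitions above being in scope) would be:

import jax.numpy as jnp

ref_grads = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}
test_grads = {"dense": {"kernel": jnp.ones((4, 4)) + 1e-7, "bias": jnp.zeros(4)}}

# Passes: the structures match and each leaf is close within the tolerances.
compare_dict(ref_grads, test_grads, rtol=1e-5, atol=1e-8)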
class TestLayer:

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        outs = module.apply(inner_variables, *inner_inputs)
        out = outs
        if isinstance(outs, tuple):
            # The first place of outs is the real output, others are auxiliary values.
            out = outs[0]
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        raise NotImplementedError

    def get_layer_name(self):
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        synced_praxis_variables = praxis_variables
        lyr_name = self.get_layer_name()
        if "params" in flax_variables:
            synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
                flax_variables["params"]
            )
        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        synced_praxis_grads = praxis_wgrads
        lyr_name = self.get_layer_name()
        if "params" in synced_praxis_grads:
            synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
        if FP8Helper.is_fp8_enabled():
            synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
                FP8Helper.FP8_COLLECTION_NAME
            ][lyr_name]["cld"]
        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08):
        init_key = jax.random.PRNGKey(seed=1234)
        test_inputs = self.input_getter(data_shape, dtype)

        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)

        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        if "params_axes" in flax_variables:
            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
        if FP8Helper.is_fp8_enabled():
            flax_variables, _ = flax.core.pop(flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes")

        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

        iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
        for _ in range(iter_times):
            praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
                praxis_layer, praxis_variables, *test_inputs
            )
            flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
                flax_layer, flax_variables, *test_inputs
            )
            if FP8Helper.is_fp8_enabled():
                praxis_wgrads.pop("params")
                praxis_variables = update_collections(praxis_wgrads, praxis_variables)
                flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
                flax_variables = update_collections(flax_wgrads, flax_variables)

        praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
            praxis_layer, praxis_variables, *test_inputs
        )
        flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
            flax_layer, flax_variables, *test_inputs
        )
        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)

        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)
class LayerNormAttr:
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]
class TestLayerNorm(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "layer_norm"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        layernorm_type = attrs[LayerNormAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormAttr.ZERO_CEN]
        scale_init = None
        bias_init = WeightInit.Constant(0.0)
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNorm, name="layer_norm", dtype=dtype, layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma, scale_init=scale_init, bias_init=bias_init,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_LayerNorm, layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma,
            scale_init=scale_init,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
            dtype=dtype, transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr:
    SCALE_FACTOR = "scale_factor"
    ST_TYPE = "softmax_type"
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]
class TestFusedSoftmax(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8)  # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]

        praxis_p = pax_fiddle.Config(
            FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
        )
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        return praxis_p, flax_cls

    def sync_variables(self, praxis_variables, flax_variables):
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
            data_shape[-2] != data_shape[-1]
        ):
            pass  # Skip, due to not support
        else:
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]
class TestLinear(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        out_features = attrs[LinearAttr.FEATURE]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            Linear, name="linear", dtype=dtype, out_features=out_features, params_init=kernel_init,
            use_bias=use_bias, bias_init=bias_init, axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            DenseGeneral, features=out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis, dtype=dtype, transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
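
The *_fp8 tests above switch the TE modules into FP8 by wrapping the run in transformer_engine.jax.fp8_autocast with a DelayedScaling recipe. A stripped-down sketch of that context-manager pattern on a single DenseGeneral (the shapes and the standalone module choice are illustrative, and an FP8-capable GPU is required):

import jax
import jax.numpy as jnp
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import DenseGeneral

x = jax.random.normal(jax.random.PRNGKey(0), (32, 128, 512), jnp.bfloat16)

# Modules initialized and applied inside the context pick up the FP8 recipe.
with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID)):
    layer = DenseGeneral(features=512, dtype=jnp.bfloat16)
    variables = layer.init(jax.random.PRNGKey(1), x)
    y = layer.apply(variables, x)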
class LayerNormLinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
    ]
class TestLayerNormLinear(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "ln_linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        out_features = attrs[LayerNormLinearAttr.FEATURE]
        enable_layernorm = attrs[LayerNormLinearAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormLinearAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormLinearAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormLinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormLinear, name="ln_linear", dtype=dtype, out_features=out_features,
            enable_layernorm=enable_layernorm, layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma, params_init=kernel_init, use_bias=use_bias,
            bias_init=bias_init, axis=axis, transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            LayerNormDenseGeneral, features=out_features, enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis, dtype=dtype, transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormMLPAttr:
    INTERMEDIATE_DIM = "intermediate_dim"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ACTIVATION = "activations"
    ATTRS = [
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm",
         ZERO_CEN: False, ACTIVATION: ("relu",)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm",
         ZERO_CEN: True, ACTIVATION: ("relu",)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm",
         ZERO_CEN: False, ACTIVATION: ("relu",)},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm",
         ZERO_CEN: False, ACTIVATION: ("gelu", "linear")},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: False, ENABLE_LN: True, LN_TYPE: "rmsnorm",
         ZERO_CEN: False, ACTIVATION: ("gelu", "linear")},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm",
         ZERO_CEN: False, ACTIVATION: ("silu", "linear")},
        {INTERMEDIATE_DIM: 2048, USE_BIAS: False, ENABLE_LN: True, LN_TYPE: "rmsnorm",
         ZERO_CEN: False, ACTIVATION: ("silu", "linear")},
    ]
class TestLayerNormMLP(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "ln_mlp"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
        enable_layernorm = attrs[LayerNormMLPAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormMLPAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormMLPAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormMLPAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        activations = attrs[LayerNormMLPAttr.ACTIVATION]
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormMLP, name="ln_mlp", dtype=dtype, intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm, layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma, params_init=kernel_init, use_bias=use_bias,
            bias_init=bias_init, activations=activations, intermediate_dropout_rate=0.0, axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_LayerNormMLP, intermediate_dim=intermediate_dim, enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            activations=activations, intermediate_dropout_rate=0.0, axis=axis, dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TestRelativePositionBias(TestLayer):

    def get_layer_name(self):
        return "relative_position_bias"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        num_buckets = 32
        max_distance = 128
        num_attention_heads = 64
        rb_stddev = (num_attention_heads * num_buckets) ** -0.5
        embedding_init = WeightInit.Gaussian(rb_stddev)

        praxis_p = pax_fiddle.Config(
            RelativePositionBiases, name="relative_position_bias", dtype=dtype,
            num_buckets=num_buckets, max_distance=max_distance,
            num_attention_heads=num_attention_heads, embedding_init=embedding_init,
        )
        flax_cls = partial(
            flax_RelativePositionBiases, num_buckets=num_buckets, max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            dtype=dtype,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)

        init_key = jax.random.PRNGKey(seed=1234)
        test_inputs = [(128, 128, True), (128, 128, False)]
        for test_input in test_inputs:
            praxis_layer = praxis_p.Instantiate()
            praxis_variables = praxis_layer.init(init_key, *test_input)

            flax_layer = flax_cls()
            flax_variables = flax_layer.init(init_key, *test_input)
            if "params_axes" in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            if FP8Helper.is_fp8_enabled():
                flax_variables, _ = flax.core.pop(
                    flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
                )

            praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

            praxis_loss = TestLayer.loss(
                praxis_variables, *test_input, module=praxis_layer, mean_out=False
            )
            flax_loss = TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)

            assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
    ATTN_MASK_TYPE = "attn_mask_type"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    SCALE_FACTOR = "scale_factor"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding_causal", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: False, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding_causal", TRANSPOSE_BS: False, SCALE_FACTOR: 2.0},
        {ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
        {ATTN_MASK_TYPE: "no_mask", TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
        {ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False, SCALE_FACTOR: 1.0,
         WINDOW_SIZE: (64, 0)},  # Left size must <= S in DATA_SHAPE
    ]
class TestDotProductAttn(TestLayer):

    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
            mask,
        ]

    def get_layer_name(self):
        return "dot_product_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = num_attention_heads
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)

        praxis_p = pax_fiddle.Config(
            DotProductAttention, name="mha", dtype=dtype, head_dim=head_dim,
            num_attention_heads=num_attention_heads, num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type, transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_DotProductAttention, dtype=dtype, head_dim=head_dim,
            num_attention_heads=num_attention_heads, num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type, transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr:
    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ATTN_MASK_TYPE = "attn_mask_type"
    ZERO_CEN = "zero_centered_gamma"
    NUM_ATTN_HEADS = "num_attention_heads"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", NUM_ATTN_HEADS: 8, NUM_GQA_GROUPS: 4,
         ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "consecutive", NUM_ATTN_HEADS: 8, NUM_GQA_GROUPS: 4,
         ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "alternate", NUM_ATTN_HEADS: 8, NUM_GQA_GROUPS: 4,
         ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "padding", LORA_SCOPE: "all",
         TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "causal", LORA_SCOPE: "all",
         TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", ATTN_MASK_TYPE: "causal", LORA_SCOPE: "all",
         TRANSPOSE_BS: True, WINDOW_SIZE: (64, 0)},  # Left size must <= S in DATA_SHAPE
    ]
class TestMultiHeadAttn(TestLayer):

    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

    def get_layer_name(self):
        return "multi_head_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = (
            attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
            if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
            else None
        )
        layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
        zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        input_layernorm = False
        return_layernorm_output = False
        attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
        enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
        fuse_qkv_params = True
        transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
        scale_attn_logits = False
        scaled_query_init = True
        float32_logits = False
        window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None)

        praxis_p = pax_fiddle.Config(
            MultiHeadAttention, name="mha", dtype=dtype, head_dim=head_dim,
            num_attention_heads=num_attention_heads, num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init, use_bias=use_bias, bias_init=bias_init,
            return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init, float32_logits=float32_logits,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_MultiHeadAttention, dtype=dtype, head_dim=head_dim,
            num_attention_heads=num_attention_heads, num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type, zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            return_layernorm_output=return_layernorm_output, input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type, enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope, fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence, scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init, float32_logits=float32_logits,
            window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TransformerLayerAttr:
    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ACTIVATION = "activations"
    LYR_TYPE = "layer_type"
    ZERO_CEN = "zero_centered_gamma"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False, LORA_SCOPE: "all"},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: True},
        {USE_BIAS: True, LN_TYPE: "rmsnorm", ZERO_CEN: False, ACTIVATION: ("gelu", "linear"),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "alternate", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "alternate", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: True, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: True,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("gelu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False, LORA_SCOPE: "all"},
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.ENCODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False,
         WINDOW_SIZE: (64, 0)},  # Left size must <= S in DATA_SHAPE
        {USE_BIAS: True, LN_TYPE: "layernorm", ZERO_CEN: False, ACTIVATION: ("relu",),
         LYR_TYPE: TransformerLayerType.DECODER, ENABLE_ROPE: False,
         ROPE_GROUP_METHOD: "consecutive", TRANSPOSE_BS: False,
         WINDOW_SIZE: (64, 0)},  # Left size must <= S in DATA_SHAPE
    ]
class TestTransformer(TestLayer):

    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
            mask,
            mask,
        ]

    def get_layer_name(self):
        return "transformerlayer"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        hidden_size = 512
        mlp_hidden_size = 2048
        num_attention_heads = 8
        layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
        hidden_dropout = 0.0
        attention_dropout = 0.0
        intermediate_dropout = 0.0
        mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[TransformerLayerAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
        enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
        enable_relative_embedding = True
        relative_embedding = pax_fiddle.Config(
            RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
        )
        drop_path = 0.0
        transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
        window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None)

        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init,
            relative_embedding.num_attention_heads,
            relative_embedding.num_buckets,
        )
        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init
            ),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype,
        )

        praxis_p = pax_fiddle.Config(
            TransformerLayer, name="transformer_layer", params_init=kernel_init, dtype=dtype,
            hidden_size=hidden_size, mlp_hidden_size=mlp_hidden_size,
            num_attention_heads=num_attention_heads, layernorm_type=layernorm_type,
            hidden_dropout=hidden_dropout, attention_dropout=attention_dropout,
            intermediate_dropout=intermediate_dropout, mlp_activations=mlp_activations,
            use_bias=use_bias, bias_init=bias_init, layer_type=layer_type,
            enable_relative_embedding=enable_relative_embedding,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            relative_embedding=relative_embedding, drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence, window_size=window_size,
        )
        flax_cls = partial(
            flax_TransformerLayer, dtype=dtype, hidden_size=hidden_size,
            mlp_hidden_size=mlp_hidden_size, num_attention_heads=num_attention_heads,
            layernorm_type=layernorm_type, hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout, intermediate_dropout=intermediate_dropout,
            mlp_activations=mlp_activations,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", kernel_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", kernel_init
            ),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            layer_type=layer_type, enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            enable_relative_embedding=enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            low_rank_adaptation_scope=low_rank_adaptation_scope, drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence, window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)