Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a207db1d
Commit
a207db1d
authored
Apr 01, 2025
by
yuguo
Browse files
Merge branch 'main' of
https://github.com/NVIDIA/TransformerEngine
parents
fbee8990
69365f88
Changes
101
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1659 additions
and
2477 deletions
+1659
-2477
docs/api/pytorch.rst
docs/api/pytorch.rst
+1
-0
examples/jax/encoder/common.py
examples/jax/encoder/common.py
+20
-0
examples/jax/encoder/run_test_multiprocessing_encoder.sh
examples/jax/encoder/run_test_multiprocessing_encoder.sh
+7
-1
examples/jax/encoder/test_model_parallel_encoder.py
examples/jax/encoder/test_model_parallel_encoder.py
+63
-22
examples/jax/encoder/test_multigpu_encoder.py
examples/jax/encoder/test_multigpu_encoder.py
+45
-15
examples/jax/encoder/test_multiprocessing_encoder.py
examples/jax/encoder/test_multiprocessing_encoder.py
+59
-31
examples/jax/encoder/test_single_gpu_encoder.py
examples/jax/encoder/test_single_gpu_encoder.py
+31
-10
examples/jax/mnist/test_single_gpu_mnist.py
examples/jax/mnist/test_single_gpu_mnist.py
+41
-10
qa/L0_jax_unittest/test.sh
qa/L0_jax_unittest/test.sh
+3
-4
qa/L0_pytorch_unittest/test.sh
qa/L0_pytorch_unittest/test.sh
+1
-1
qa/L2_jax_unittest/test.sh
qa/L2_jax_unittest/test.sh
+23
-0
tests/jax/distributed_test_base.py
tests/jax/distributed_test_base.py
+1
-1
tests/jax/test_custom_call_compute.py
tests/jax/test_custom_call_compute.py
+1107
-747
tests/jax/test_distributed_fused_attn.py
tests/jax/test_distributed_fused_attn.py
+4
-5
tests/jax/test_distributed_layernorm.py
tests/jax/test_distributed_layernorm.py
+65
-21
tests/jax/test_distributed_layernorm_mlp.py
tests/jax/test_distributed_layernorm_mlp.py
+102
-99
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+1
-1
tests/jax/test_helper.py
tests/jax/test_helper.py
+22
-22
tests/jax/test_layer.py
tests/jax/test_layer.py
+63
-51
tests/jax/test_praxis_layers.py
tests/jax/test_praxis_layers.py
+0
-1436
No files found.
docs/api/pytorch.rst
View file @
a207db1d
...
...
@@ -32,6 +32,7 @@ pyTorch
:members: forward, set_context_parallel_group, set_tensor_parallel_group
.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
:members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
.. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
:members: reset, get_states, set_states, add, fork
...
...
examples/jax/encoder/common.py
View file @
a207db1d
...
...
@@ -4,7 +4,9 @@
"""Shared functions for the encoder tests"""
from
functools
import
lru_cache
import
transformer_engine
from
transformer_engine_jax
import
get_device_compute_capability
from
transformer_engine.common
import
recipe
@
lru_cache
...
...
@@ -19,3 +21,21 @@ def is_fp8_supported():
"""Return if FP8 has hardware supported"""
gpu_arch
=
get_device_compute_capability
(
0
)
return
gpu_arch
>=
90
@lru_cache
def is_mxfp8_supported():
    """Return whether the device meets the MXFP8 hardware requirement.

    The check compares the value reported by
    ``get_device_compute_capability(0)`` against 100. The result is
    memoized by ``lru_cache``, so the capability is only queried on the
    first call.
    """
    # NOTE(review): the argument 0 is presumably the GPU index and 100
    # presumably encodes compute capability 10.0 -- confirm against
    # transformer_engine_jax.get_device_compute_capability.
    return get_device_compute_capability(0) >= 100
def get_fp8_recipe_from_name_string(name: str):
    """Build the FP8 recipe object that corresponds to a recipe name.

    Args:
        name: Recipe name; ``"DelayedScaling"`` and ``"MXFP8BlockScaling"``
            are the recognized values.

    Returns:
        A freshly constructed ``recipe.DelayedScaling()`` or
        ``recipe.MXFP8BlockScaling()`` instance.

    Raises:
        ValueError: If ``name`` is not one of the recognized recipe names.
    """
    # Guard clauses: each recognized name returns immediately, so reaching
    # the end of the function means the name is unknown.
    if name == "DelayedScaling":
        return recipe.DelayedScaling()
    if name == "MXFP8BlockScaling":
        return recipe.MXFP8BlockScaling()
    raise ValueError(f"Invalid fp8_recipe, got {name}")
examples/jax/encoder/run_test_multiprocessing_encoder.sh
View file @
a207db1d
...
...
@@ -12,6 +12,12 @@ wait
for
i
in
$(
seq
0
$((
$NUM_GPUS
-
1
))
)
do
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8
--num-process
=
$NUM_GPUS
--process-id
=
$i
&
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8
--num-process
=
$NUM_GPUS
--process-id
=
$i
&
done
wait
for
i
in
$(
seq
0
$((
$NUM_GPUS
-
1
))
)
do
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8
--num-process
=
$NUM_GPUS
--process-id
=
$i
&
done
wait
examples/jax/encoder/test_model_parallel_encoder.py
View file @
a207db1d
...
...
@@ -19,10 +19,11 @@ from flax.training import train_state
from
jax.experimental
import
mesh_utils
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
common
import
is_bf16_supported
,
get_fp8_recipe_from_name_string
import
transformer_engine.jax
as
te
import
transformer_engine.jax.flax
as
te_flax
from
transformer_engine.jax.quantize
import
is_fp8_available
,
ScalingMode
from
common
import
is_bf16_supported
DEVICE_DP_AXIS
=
"data"
DEVICE_TP_AXIS
=
"model"
...
...
@@ -217,9 +218,8 @@ def get_datasets(max_seq_len):
def
check_fp8
(
state
,
var_collect
,
inputs
,
masks
,
labels
):
"Check if model includes FP8."
rngs
=
{
DROPOUT_KEY
:
jax
.
random
.
PRNGKey
(
0
)}
assert
"fp8_"
in
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
)
)
func_jaxpr
=
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
))
assert
"f8_e5m2"
in
func_jaxpr
or
"f8_e4m3"
in
func_jaxpr
def
get_params_sharding
(
sharding_rules
,
abs_var_collect
,
mesh
):
...
...
@@ -272,6 +272,19 @@ def train_and_evaluate(args):
args
.
test_batch_size
%
num_gpu_dp
==
0
),
f
"Test batch size needs to be multiple of
{
num_gpu_dp
}
"
if
args
.
fp8_recipe
==
"MXFP8BlockScaling"
:
assert
(
args
.
batch_size
/
num_gpu_dp
%
32
==
0
),
"Batch size needs to be multiple of 32 for MXFP8"
assert
(
args
.
test_batch_size
/
num_gpu_dp
%
32
==
0
),
"Test batch size needs to be multiple of 32 for MXFP8"
if
args
.
use_fp8
:
fp8_recipe
=
get_fp8_recipe_from_name_string
(
args
.
fp8_recipe
)
else
:
fp8_recipe
=
None
device_mesh
=
mesh_utils
.
create_device_mesh
((
num_gpu_dp
,
num_gpu_tp
))
with
jax
.
sharding
.
Mesh
(
devices
=
device_mesh
,
axis_names
=
(
DEVICE_DP_AXIS
,
DEVICE_TP_AXIS
)
...
...
@@ -287,7 +300,9 @@ def train_and_evaluate(args):
label_shape
=
[
args
.
batch_size
]
with
te
.
fp8_autocast
(
args
.
use_fp8
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
DEVICE_TP_AXIS
,
None
,
None
)
enabled
=
args
.
use_fp8
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
DEVICE_TP_AXIS
,
None
,
None
),
):
encoder
=
Net
(
num_embed
,
args
.
enable_sp
)
inputs
=
jnp
.
zeros
(
input_shape
,
dtype
=
jnp
.
int32
)
...
...
@@ -371,21 +386,21 @@ def encoder_parser(args):
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
64
,
default
=
128
,
metavar
=
"N"
,
help
=
"input batch size for training (default:
64
)"
,
help
=
"input batch size for training (default:
128
)"
,
)
parser
.
add_argument
(
"--test-batch-size"
,
type
=
int
,
default
=
64
,
default
=
128
,
metavar
=
"N"
,
help
=
"input batch size for testing (default:
64
)"
,
help
=
"input batch size for testing (default:
128
)"
,
)
parser
.
add_argument
(
"--max-seq-len"
,
type
=
int
,
default
=
32
,
default
=
64
,
metavar
=
"N"
,
help
=
"maximum sequence length (default: 32)"
,
)
...
...
@@ -416,6 +431,12 @@ def encoder_parser(args):
default
=
False
,
help
=
"Use FP8 for inference and training without recalibration"
,
)
parser
.
add_argument
(
"--fp8-recipe"
,
action
=
"store_true"
,
default
=
"DelayedScaling"
,
help
=
"Use FP8 recipe (default: DelayedScaling)"
,
)
parser
.
add_argument
(
"--enable-sp"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Enable sequence parallelism."
)
...
...
@@ -426,7 +447,8 @@ def encoder_parser(args):
class
TestEncoder
(
unittest
.
TestCase
):
"""Encoder unittests"""
gpu_has_fp8
,
reason
=
te
.
fp8
.
is_fp8_available
()
is_fp8_supported
,
fp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -437,29 +459,48 @@ class TestEncoder(unittest.TestCase):
def
test_te_bf16
(
self
):
"""Test Transformer Engine with BF16"""
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.45
and
actual
[
1
]
>
0.79
assert
actual
[
0
]
<
0.50
and
actual
[
1
]
>
0.76
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8
(
self
):
"""Test Transformer Engine with DelayedScaling FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.50
and
actual
[
1
]
>
0.76
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
reason
)
def
test_te_fp8
(
self
):
"""Test Transformer Engine with FP8"""
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_
reason
)
def
test_te_
mx
fp8
(
self
):
"""Test Transformer Engine with
MX
FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"MXFP8BlockScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.
45
5
and
actual
[
1
]
>
0.7
85
assert
actual
[
0
]
<
0.5
0
and
actual
[
1
]
>
0.7
6
@
unittest
.
skipIf
(
not
is_bf16_supported
(),
"Device compute capability 8.0+ is required for BF16"
)
def
test_te_bf16_sp
(
self
):
def
test_te_bf16_
with_
sp
(
self
):
"""Test Transformer Engine with BF16 + SP"""
self
.
args
.
enable_sp
=
True
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.45
and
actual
[
1
]
>
0.79
assert
actual
[
0
]
<
0.50
and
actual
[
1
]
>
0.76
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8_with_sp
(
self
):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self
.
args
.
enable_sp
=
True
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.50
and
actual
[
1
]
>
0.76
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
reason
)
def
test_te_fp8_sp
(
self
):
"""Test Transformer Engine with FP8 + SP"""
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_
reason
)
def
test_te_
mx
fp8_
with_
sp
(
self
):
"""Test Transformer Engine with
MX
FP8 + SP"""
self
.
args
.
enable_sp
=
True
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"MXFP8BlockScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.
45
5
and
actual
[
1
]
>
0.7
85
assert
actual
[
0
]
<
0.5
0
and
actual
[
1
]
>
0.7
6
if
__name__
==
"__main__"
:
...
...
examples/jax/encoder/test_multigpu_encoder.py
View file @
a207db1d
...
...
@@ -19,10 +19,11 @@ from flax.training import train_state
from
jax.experimental
import
mesh_utils
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
common
import
is_bf16_supported
,
get_fp8_recipe_from_name_string
import
transformer_engine.jax
as
te
import
transformer_engine.jax.flax
as
te_flax
from
transformer_engine.jax.quantize
import
is_fp8_available
,
ScalingMode
from
common
import
is_bf16_supported
DEVICE_DP_AXIS
=
"data"
PARAMS_KEY
=
"params"
...
...
@@ -198,9 +199,8 @@ def get_datasets(max_seq_len):
def
check_fp8
(
state
,
var_collect
,
inputs
,
masks
,
labels
):
"Check if model includes FP8."
rngs
=
{
DROPOUT_KEY
:
jax
.
random
.
PRNGKey
(
0
)}
assert
"fp8_"
in
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
)
)
func_jaxpr
=
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
))
assert
"f8_e5m2"
in
func_jaxpr
or
"f8_e4m3"
in
func_jaxpr
def
get_params_sharding
(
sharding_rules
,
abs_var_collect
,
mesh
):
...
...
@@ -243,6 +243,18 @@ def train_and_evaluate(args):
num_gpu
=
jax
.
local_device_count
()
assert
args
.
batch_size
%
num_gpu
==
0
,
f
"Batch size needs to be multiple of
{
num_gpu
}
"
assert
args
.
test_batch_size
%
num_gpu
==
0
,
f
"Test batch size needs to be multiple of
{
num_gpu
}
"
if
args
.
fp8_recipe
==
"MXFP8BlockScaling"
:
assert
(
args
.
batch_size
/
num_gpu
%
32
==
0
),
"Batch size needs to be multiple of 32 for MXFP8"
assert
(
args
.
test_batch_size
/
num_gpu
%
32
==
0
),
"Test batch size needs to be multiple of 32 for MXFP8"
if
args
.
use_fp8
:
fp8_recipe
=
get_fp8_recipe_from_name_string
(
args
.
fp8_recipe
)
else
:
fp8_recipe
=
None
device_mesh
=
mesh_utils
.
create_device_mesh
((
num_gpu
,))
with
jax
.
sharding
.
Mesh
(
devices
=
device_mesh
,
axis_names
=
(
DEVICE_DP_AXIS
,))
as
mesh
:
...
...
@@ -257,7 +269,9 @@ def train_and_evaluate(args):
label_shape
=
[
args
.
batch_size
]
with
te
.
fp8_autocast
(
args
.
use_fp8
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
None
,
None
,
None
)
enabled
=
args
.
use_fp8
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
None
,
None
,
None
),
):
encoder
=
Net
(
num_embed
)
inputs
=
jnp
.
zeros
(
input_shape
,
dtype
=
jnp
.
int32
)
...
...
@@ -344,16 +358,16 @@ def encoder_parser(args):
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
128
,
default
=
256
,
metavar
=
"N"
,
help
=
"input batch size for training (default:
128
)"
,
help
=
"input batch size for training (default:
256
)"
,
)
parser
.
add_argument
(
"--test-batch-size"
,
type
=
int
,
default
=
128
,
default
=
256
,
metavar
=
"N"
,
help
=
"input batch size for testing (default:
128
)"
,
help
=
"input batch size for testing (default:
256
)"
,
)
parser
.
add_argument
(
"--max-seq-len"
,
...
...
@@ -389,6 +403,12 @@ def encoder_parser(args):
default
=
False
,
help
=
"Use FP8 for inference and training without recalibration"
,
)
parser
.
add_argument
(
"--fp8-recipe"
,
action
=
"store_true"
,
default
=
"DelayedScaling"
,
help
=
"Use FP8 recipe (default: DelayedScaling)"
,
)
return
parser
.
parse_args
(
args
)
...
...
@@ -396,7 +416,8 @@ def encoder_parser(args):
class
TestEncoder
(
unittest
.
TestCase
):
"""Encoder unittests"""
gpu_has_fp8
,
reason
=
te
.
fp8
.
is_fp8_available
()
is_fp8_supported
,
fp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -407,14 +428,23 @@ class TestEncoder(unittest.TestCase):
def
test_te_bf16
(
self
):
"""Test Transformer Engine with BF16"""
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.50
and
actual
[
1
]
>
0.76
assert
actual
[
0
]
<
0.535
and
actual
[
1
]
>
0.73
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8
(
self
):
"""Test Transformer Engine with DelayedScaling FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.535
and
actual
[
1
]
>
0.73
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
reason
)
def
test_te_fp8
(
self
):
"""Test Transformer Engine with FP8"""
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_
reason
)
def
test_te_
mx
fp8
(
self
):
"""Test Transformer Engine with
MX
FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"MXFP8BlockScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.5
0
and
actual
[
1
]
>
0.7
6
assert
actual
[
0
]
<
0.5
35
and
actual
[
1
]
>
0.7
3
if
__name__
==
"__main__"
:
...
...
examples/jax/encoder/test_multiprocessing_encoder.py
View file @
a207db1d
...
...
@@ -21,9 +21,15 @@ from flax.training import train_state
from
jax.experimental
import
mesh_utils
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
common
import
is_bf16_supported
,
is_fp8_supported
from
common
import
(
is_bf16_supported
,
is_fp8_supported
,
is_mxfp8_supported
,
get_fp8_recipe_from_name_string
,
)
import
transformer_engine.jax
as
te
import
transformer_engine.jax.flax
as
te_flax
from
transformer_engine.jax.quantize
import
is_fp8_available
,
ScalingMode
os
.
environ
[
"CUDA_DEVICE_ORDER"
]
=
"PCI_BUS_ID"
...
...
@@ -298,9 +304,8 @@ def get_datasets(max_seq_len):
def
check_fp8
(
state
,
var_collect
,
inputs
,
masks
,
labels
):
"Check if model includes FP8."
rngs
=
{
DROPOUT_KEY
:
jax
.
random
.
PRNGKey
(
0
)}
assert
"fp8_"
in
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
)
)
func_jaxpr
=
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
))
assert
"f8_e5m2"
in
func_jaxpr
or
"f8_e4m3"
in
func_jaxpr
def
get_params_sharding
(
sharding_rules
,
abs_var_collect
,
mesh
):
...
...
@@ -359,10 +364,16 @@ def train_and_evaluate(args):
num_gpu_dp
=
1
num_gpu_tp
=
1
assert
args
.
batch_size
%
num_gpu_dp
==
0
,
f
"Batch size needs to be multiple of
{
num_gpu_dp
}
"
if
args
.
fp8_recipe
==
"MXFP8BlockScaling"
:
assert
args
.
batch_size
%
32
==
0
,
"Batch size needs to be multiple of 32 for MXFP8"
assert
(
args
.
test_batch_size
%
num_gpu_dp
==
0
),
f
"Test batch size needs to be multiple of
{
num_gpu_dp
}
"
args
.
test_batch_size
%
32
==
0
),
"Test batch size needs to be multiple of 32 for MXFP8"
if
args
.
use_fp8
:
fp8_recipe
=
get_fp8_recipe_from_name_string
(
args
.
fp8_recipe
)
else
:
fp8_recipe
=
None
device_mesh
=
mesh_utils
.
create_device_mesh
((
num_gpu_dp
,
num_gpu_tp
))
with
jax
.
sharding
.
Mesh
(
...
...
@@ -379,7 +390,9 @@ def train_and_evaluate(args):
label_shape
=
[
args
.
batch_size
]
with
te
.
fp8_autocast
(
args
.
use_fp8
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
DEVICE_TP_AXIS
,
None
,
None
)
enabled
=
args
.
use_fp8
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
te
.
MeshResource
(
DEVICE_DP_AXIS
,
DEVICE_TP_AXIS
,
None
,
None
),
):
encoder
=
Net
(
num_embed
)
inputs
=
jnp
.
zeros
(
input_shape
,
dtype
=
jnp
.
int32
)
...
...
@@ -482,23 +495,23 @@ def encoder_parser(args):
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
64
,
default
=
128
,
metavar
=
"N"
,
help
=
"input batch size for training (default:
64
)"
,
help
=
"input batch size for training (default:
128
)"
,
)
parser
.
add_argument
(
"--test-batch-size"
,
type
=
int
,
default
=
64
,
default
=
128
,
metavar
=
"N"
,
help
=
"input batch size for testing (default:
64
)"
,
help
=
"input batch size for testing (default:
128
)"
,
)
parser
.
add_argument
(
"--max-seq-len"
,
type
=
int
,
default
=
32
,
default
=
64
,
metavar
=
"N"
,
help
=
"maximum sequence length (default:
32
)"
,
help
=
"maximum sequence length (default:
64
)"
,
)
parser
.
add_argument
(
"--epochs"
,
...
...
@@ -527,6 +540,12 @@ def encoder_parser(args):
default
=
False
,
help
=
"Use FP8 for inference and training without recalibration"
,
)
parser
.
add_argument
(
"--fp8-recipe"
,
action
=
"store_true"
,
default
=
"DelayedScaling"
,
help
=
"Use FP8 recipe (default: DelayedScaling)"
,
)
parser
.
add_argument
(
"--coordinator-address"
,
type
=
str
,
...
...
@@ -554,37 +573,46 @@ def encoder_parser(args):
class
TestEncoder
(
unittest
.
TestCase
):
"""Encoder unittests"""
gpu_has_fp8
=
is_fp8_supported
()
gpu_has_bf16
=
is_bf16_supported
()
def
exec
(
self
,
use_fp8
):
def
exec
(
self
,
use_fp8
,
fp8_recipe
):
"""Run 3 epochs for testing"""
args
=
encoder_parser
([])
num_gpu
=
self
.
num_process
tp_size
=
2
if
num_gpu
>
1
and
num_gpu
%
2
==
0
else
1
dp_size
=
num_gpu
//
tp_size
batch_size
=
64
//
dp_size
assert
args
.
batch_size
%
dp_size
==
0
,
f
"Batch size needs to be multiple of
{
dp_size
}
"
batch_size
=
args
.
batch_size
//
dp_size
args
.
use_fp8
=
use_fp8
args
.
batch_size
=
batch_size
args
.
test_batch_size
=
batch_size
args
.
num_process
=
num_gpu
args
.
process_id
=
self
.
process_id
args
.
fp8_recipe
=
fp8_recipe
return
train_and_evaluate
(
args
)
@
unittest
.
skipIf
(
not
gpu_has_bf16
,
"Device compute capability 8.0+ is required for BF16"
)
@
unittest
.
skipIf
(
not
is_bf16_supported
()
,
"Device compute capability 8.0+ is required for BF16"
)
def
test_te_bf16
(
self
):
"""Test Transformer Engine with BF16"""
result
=
self
.
exec
(
False
)
assert
result
[
0
]
<
0.45
and
result
[
1
]
>
0.79
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
"Device compute capability 9.0+ is required for FP8"
)
def
test_te_fp8
(
self
):
"""Test Transformer Engine with FP8"""
result
=
self
.
exec
(
True
)
assert
result
[
0
]
<
0.455
and
result
[
1
]
>
0.79
result
=
self
.
exec
(
False
,
None
)
assert
result
[
0
]
<
0.505
and
result
[
1
]
>
0.755
@
unittest
.
skipIf
(
not
is_fp8_supported
(),
"Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def
test_te_delayed_scaling_fp8
(
self
):
"""Test Transformer Engine with DelayedScaling FP8"""
result
=
self
.
exec
(
True
,
"DelayedScaling"
)
assert
result
[
0
]
<
0.505
and
result
[
1
]
>
0.755
@
unittest
.
skipIf
(
not
is_mxfp8_supported
(),
"Device compute capability 10.0+ is required for MXFP8"
)
def
test_te_mxfp8
(
self
):
"""Test Transformer Engine with MXFP8"""
result
=
self
.
exec
(
True
,
"MXFP8BlockScaling"
)
assert
result
[
0
]
<
0.505
and
result
[
1
]
>
0.754
if
__name__
==
"__main__"
:
...
...
examples/jax/encoder/test_single_gpu_encoder.py
View file @
a207db1d
...
...
@@ -16,10 +16,11 @@ from datasets import load_dataset
from
flax
import
linen
as
nn
from
flax.training
import
train_state
from
common
import
is_bf16_supported
,
get_fp8_recipe_from_name_string
import
transformer_engine.jax
as
te
import
transformer_engine.jax.flax
as
te_flax
from
transformer_engine.jax.quantize
import
is_fp8_available
,
ScalingMode
from
common
import
is_bf16_supported
PARAMS_KEY
=
"params"
DROPOUT_KEY
=
"dropout"
...
...
@@ -59,7 +60,7 @@ class Net(nn.Module):
return
x
@
partial
(
jax
.
jit
)
@
jax
.
jit
def
train_step
(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
):
"""Computes gradients, loss and accuracy for a single batch."""
...
...
@@ -195,9 +196,8 @@ def get_datasets(max_seq_len):
def
check_fp8
(
state
,
var_collect
,
inputs
,
masks
,
labels
):
"Check if model includes FP8."
rngs
=
{
DROPOUT_KEY
:
jax
.
random
.
PRNGKey
(
0
)}
assert
"fp8_"
in
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
)
)
func_jaxpr
=
str
(
jax
.
make_jaxpr
(
train_step
)(
state
,
inputs
,
masks
,
labels
,
var_collect
,
rngs
))
assert
"f8_e5m2"
in
func_jaxpr
or
"f8_e4m3"
in
func_jaxpr
def
train_and_evaluate
(
args
):
...
...
@@ -214,7 +214,12 @@ def train_and_evaluate(args):
mask_shape
=
[
args
.
batch_size
,
1
,
args
.
max_seq_len
,
args
.
max_seq_len
]
label_shape
=
[
args
.
batch_size
]
with
te
.
fp8_autocast
(
enabled
=
args
.
use_fp8
):
if
args
.
use_fp8
:
fp8_recipe
=
get_fp8_recipe_from_name_string
(
args
.
fp8_recipe
)
else
:
fp8_recipe
=
None
with
te
.
fp8_autocast
(
enabled
=
args
.
use_fp8
,
fp8_recipe
=
fp8_recipe
):
encoder
=
Net
(
num_embed
)
# We use nn.Embed, thus inputs need to be in int
inputs
=
jnp
.
zeros
(
input_shape
,
dtype
=
jnp
.
int32
)
...
...
@@ -309,6 +314,12 @@ def encoder_parser(args):
default
=
False
,
help
=
"Use FP8 for inference and training without recalibration"
,
)
parser
.
add_argument
(
"--fp8-recipe"
,
action
=
"store_true"
,
default
=
"DelayedScaling"
,
help
=
"Use FP8 recipe (default: DelayedScaling)"
,
)
return
parser
.
parse_args
(
args
)
...
...
@@ -316,7 +327,8 @@ def encoder_parser(args):
class
TestEncoder
(
unittest
.
TestCase
):
"""Encoder unittests"""
gpu_has_fp8
,
reason
=
te
.
fp8
.
is_fp8_available
()
is_fp8_supported
,
fp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -329,10 +341,19 @@ class TestEncoder(unittest.TestCase):
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.45
and
actual
[
1
]
>
0.79
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
reason
)
def
test_te_fp8
(
self
):
"""Test Transformer Engine with FP8"""
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8
(
self
):
"""Test Transformer Engine with DelayedScaling FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.455
and
actual
[
1
]
>
0.79
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_reason
)
def
test_te_mxfp8
(
self
):
"""Test Transformer Engine with MXFP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"MXFP8BlockScaling"
actual
=
train_and_evaluate
(
self
.
args
)
assert
actual
[
0
]
<
0.455
and
actual
[
1
]
>
0.79
...
...
examples/jax/mnist/test_single_gpu_mnist.py
View file @
a207db1d
...
...
@@ -5,6 +5,8 @@
import
argparse
import
unittest
from
functools
import
partial
import
sys
from
pathlib
import
Path
import
jax
import
jax.numpy
as
jnp
...
...
@@ -16,6 +18,11 @@ from flax.training import train_state
import
transformer_engine.jax
as
te
import
transformer_engine.jax.flax
as
te_flax
from
transformer_engine.jax.quantize
import
is_fp8_available
,
ScalingMode
DIR
=
str
(
Path
(
__file__
).
resolve
().
parents
[
1
])
sys
.
path
.
append
(
str
(
DIR
))
from
encoder.common
import
is_bf16_supported
,
get_fp8_recipe_from_name_string
IMAGE_H
=
28
IMAGE_W
=
28
...
...
@@ -37,6 +44,7 @@ class Net(nn.Module):
else
:
nn_Dense
=
nn
.
Dense
# dtype is used for param init in TE but computation in Linen.nn
dtype
=
jnp
.
float32
if
self
.
use_te
else
jnp
.
bfloat16
x
=
nn
.
Conv
(
features
=
32
,
kernel_size
=
(
3
,
3
),
strides
=
1
,
dtype
=
jnp
.
bfloat16
)(
x
)
...
...
@@ -50,8 +58,8 @@ class Net(nn.Module):
x
=
nn_Dense
(
features
=
128
,
dtype
=
dtype
)(
x
)
x
=
nn
.
relu
(
x
)
x
=
nn
.
Dropout
(
rate
=
0.5
)(
x
,
deterministic
=
disable_dropout
)
x
=
nn_Dense
(
features
=
16
,
dtype
=
dtype
)(
x
)
x
=
nn_Dense
(
features
=
10
,
dtype
=
dtype
)(
x
)
x
=
nn_Dense
(
features
=
32
,
dtype
=
dtype
)(
x
)
x
=
nn_Dense
(
features
=
32
,
dtype
=
dtype
)(
x
)
assert
x
.
dtype
==
jnp
.
bfloat16
return
x
...
...
@@ -62,7 +70,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
def
loss_fn
(
var_collect
,
disable_dropout
=
False
):
logits
=
state
.
apply_fn
(
var_collect
,
images
,
disable_dropout
,
rngs
=
rngs
)
one_hot
=
jax
.
nn
.
one_hot
(
labels
,
10
)
one_hot
=
jax
.
nn
.
one_hot
(
labels
,
32
)
loss
=
jnp
.
mean
(
optax
.
softmax_cross_entropy
(
logits
=
logits
,
labels
=
one_hot
))
return
loss
,
logits
...
...
@@ -153,7 +161,7 @@ def get_datasets():
def
check_fp8
(
state
,
var_collect
,
input_shape
,
label_shape
):
"Check if model includes FP8."
assert
"f8_"
in
str
(
func_jaxpr
=
str
(
jax
.
make_jaxpr
(
apply_model
)(
state
,
jnp
.
empty
(
input_shape
,
dtype
=
jnp
.
bfloat16
),
...
...
@@ -161,6 +169,7 @@ def check_fp8(state, var_collect, input_shape, label_shape):
var_collect
,
)
)
assert
"f8_e5m2"
in
func_jaxpr
or
"f8_e4m3"
in
func_jaxpr
def
train_and_evaluate
(
args
):
...
...
@@ -179,7 +188,12 @@ def train_and_evaluate(args):
input_shape
=
[
args
.
batch_size
,
IMAGE_H
,
IMAGE_W
,
IMAGE_C
]
label_shape
=
[
args
.
batch_size
]
with
te
.
fp8_autocast
(
enabled
=
args
.
use_fp8
):
if
args
.
use_fp8
:
fp8_recipe
=
get_fp8_recipe_from_name_string
(
args
.
fp8_recipe
)
else
:
fp8_recipe
=
None
with
te
.
fp8_autocast
(
enabled
=
args
.
use_fp8
,
fp8_recipe
=
fp8_recipe
):
cnn
=
Net
(
args
.
use_te
)
var_collect
=
cnn
.
init
(
init_rngs
,
jnp
.
empty
(
input_shape
,
dtype
=
jnp
.
bfloat16
))
tx
=
optax
.
sgd
(
args
.
lr
,
args
.
momentum
)
...
...
@@ -276,6 +290,12 @@ def mnist_parser(args):
"It also enables Transformer Engine implicitly."
),
)
parser
.
add_argument
(
"--fp8-recipe"
,
action
=
"store_true"
,
default
=
"DelayedScaling"
,
help
=
"Use FP8 recipe (default: DelayedScaling)"
,
)
parser
.
add_argument
(
"--use-te"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use Transformer Engine"
)
...
...
@@ -286,7 +306,8 @@ def mnist_parser(args):
class
TestMNIST
(
unittest
.
TestCase
):
"""MNIST unittests"""
gpu_has_fp8
,
reason
=
te
.
fp8
.
is_fp8_available
()
is_fp8_supported
,
fp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
)
is_mxfp8_supported
,
mxfp8_reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -298,13 +319,14 @@ class TestMNIST(unittest.TestCase):
"""Check If loss and accuracy match target"""
desired_traing_loss
=
0.055
desired_traing_accuracy
=
0.98
desired_test_loss
=
0.04
desired_test_loss
=
0.04
5
desired_test_accuracy
=
0.098
assert
actual
[
0
]
<
desired_traing_loss
assert
actual
[
1
]
>
desired_traing_accuracy
assert
actual
[
2
]
<
desired_test_loss
assert
actual
[
3
]
>
desired_test_accuracy
@
unittest
.
skipIf
(
not
is_bf16_supported
(),
"Device compute capability 8.0+ is required for BF16"
)
def
test_te_bf16
(
self
):
"""Test Transformer Engine with BF16"""
self
.
args
.
use_te
=
True
...
...
@@ -312,10 +334,19 @@ class TestMNIST(unittest.TestCase):
actual
=
train_and_evaluate
(
self
.
args
)
self
.
verify
(
actual
)
@
unittest
.
skipIf
(
not
gpu_has_fp8
,
reason
)
def
test_te_fp8
(
self
):
"""Test Transformer Engine with FP8"""
@
unittest
.
skipIf
(
not
is_fp8_supported
,
fp8_reason
)
def
test_te_delayed_scaling_fp8
(
self
):
"""Test Transformer Engine with DelayedScaling FP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"DelayedScaling"
actual
=
train_and_evaluate
(
self
.
args
)
self
.
verify
(
actual
)
@
unittest
.
skipIf
(
not
is_mxfp8_supported
,
mxfp8_reason
)
def
test_te_mxfp8
(
self
):
"""Test Transformer Engine with MXFP8"""
self
.
args
.
use_fp8
=
True
self
.
args
.
fp8_recipe
=
"MXFP8BlockScaling"
actual
=
train_and_evaluate
(
self
.
args
)
self
.
verify
(
actual
)
...
...
qa/L0_jax_unittest/test.sh
View file @
a207db1d
...
...
@@ -20,16 +20,15 @@ pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3
install
pytest
==
8.2.1
||
error_exit
"Failed to install pytest"
:
${
TE_PATH
:
=/opt/transformerengine
}
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax
-k
'not distributed'
--ignore
=
$TE_PATH
/tests/jax/test_
praxis_lay
er
s
.py
||
test_fail
"test
_praxis_layers.py
"
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax
-k
'not distributed'
--ignore
=
$TE_PATH
/tests/jax/test_
help
er.py
||
test_fail
"test
s/jax/*not_distributed_*
"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE
=
""
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax/test_custom_call_compute.py
||
test_fail
"test_custom_call_compute.py"
NVTE_CUSTOM_CALLS_RE
=
""
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax/test_custom_call_compute.py
||
test_fail
"test_custom_call_compute.py
without TE custom calls
"
pip3
install
-r
$TE_PATH
/examples/jax/mnist/requirements.txt
||
error_exit
"Failed to install mnist requirements"
pip3
install
-r
$TE_PATH
/examples/jax/encoder/requirements.txt
||
error_exit
"Failed to install encoder requirements"
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/mnist
||
test_fail
"test_mnist.py"
pip3
install
-r
$TE_PATH
/examples/jax/encoder/requirements.txt
||
error_exit
"Failed to install encoder requirements"
# Make the encoder tests run-to-run deterministic so that CI results are stable
export
XLA_FLAGS
=
"
${
XLA_FLAGS
}
--xla_gpu_deterministic_ops"
python3
-m
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/encoder/test_single_gpu_encoder.py
||
test_fail
"test_single_gpu_encoder.py"
...
...
qa/L0_pytorch_unittest/test.sh
View file @
a207db1d
...
...
@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
python3
-m
pytest
-v
-s
$TE_PATH
/tests/pytorch/test_parallel_cross_entropy.py
||
test_fail
"test_parallel_cross_entropy.py"
python3
-m
pytest
-v
-s
$TE_PATH
/tests/pytorch/test_cpu_offloading.py
||
test_fail
"test_cpu_offloading.py"
NVTE_DEBUG
=
1
NVTE_DEBUG_LEVEL
=
1 python3
-m
pytest
-o
log_cli
=
true
--log-cli-level
=
INFO
-v
-s
$TE_PATH
/tests/pytorch/fused_attn/test_fused_attn.py
||
test_fail
"test_fused_attn.py"
NVTE_DEBUG
=
1
NVTE_DEBUG_LEVEL
=
1 python3
-m
pytest
-o
log_cli
=
true
--log-cli-level
=
INFO
-v
-s
$TE_PATH
/tests/pytorch/fused_attn/test_
paged_attn
.py
||
test_fail
"test_
paged_attn
.py"
NVTE_DEBUG
=
1
NVTE_DEBUG_LEVEL
=
1 python3
-m
pytest
-o
log_cli
=
true
--log-cli-level
=
INFO
-v
-s
$TE_PATH
/tests/pytorch/fused_attn/test_
kv_cache
.py
||
test_fail
"test_
kv_cache
.py"
if
[
"
$RET
"
-ne
0
]
;
then
echo
"Error in the following test cases:
$FAILED_CASES
"
...
...
qa/L2_jax_unittest/test.sh
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set
-xe
pip
install
"nltk>=3.8.2"
pip
install
pytest
==
8.2.1
:
${
TE_PATH
:
=/opt/transformerengine
}
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax
-k
'not distributed'
--ignore
=
$TE_PATH
/tests/jax/test_praxis_layers.py
# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL
=
"L2"
NVTE_CUSTOM_CALLS_RE
=
""
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/tests/jax/test_custom_call_compute.py
pip
install
-r
$TE_PATH
/examples/jax/mnist/requirements.txt
pip
install
-r
$TE_PATH
/examples/jax/encoder/requirements.txt
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/mnist
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export
XLA_FLAGS
=
"
${
XLA_FLAGS
}
--xla_gpu_deterministic_ops"
pytest
-c
$TE_PATH
/tests/jax/pytest.ini
-v
$TE_PATH
/examples/jax/encoder/test_single_gpu_encoder.py
tests/jax/distributed_test_base.py
View file @
a207db1d
...
...
@@ -82,7 +82,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'i32[1024]{0}',
'bf16[1024,1024]{0}'
"""
match
=
re
.
search
(
r
"(i|f)(\d+).*\[([0-9,]*)\]"
,
t
)
match
=
re
.
search
(
r
"(i|f
|u
)(\d+).*\[([0-9,]*)\]"
,
t
)
_
,
bits_of_type
,
shape
=
match
.
groups
()
bytes_of_type
=
int
(
bits_of_type
)
//
8
if
shape
==
""
:
...
...
tests/jax/test_custom_call_compute.py
View file @
a207db1d
This diff is collapsed.
Click to expand it.
tests/jax/test_distributed_fused_attn.py
View file @
a207db1d
...
...
@@ -6,7 +6,6 @@ import os
import
pytest
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
from
jax
import
random
from
distributed_test_base
import
(
generate_configs
,
...
...
@@ -104,7 +103,7 @@ class TestDistributedSelfAttn:
hidden
,
None
,
# no window
):
pytest
.
skip
(
f
"No FusedAttn backend found"
)
pytest
.
skip
(
"No FusedAttn backend found"
)
col_ref
=
self
.
generate_collectives_count_ref
(
mesh_shape
,
...
...
@@ -176,7 +175,7 @@ class TestDistributedCrossAttn:
hidden
,
None
,
# no window
):
pytest
.
skip
(
f
"No FusedAttn backend found"
)
pytest
.
skip
(
"No FusedAttn backend found"
)
col_ref
=
self
.
generate_collectives_count_ref
()
runner
=
FusedAttnRunner
(
...
...
@@ -256,7 +255,6 @@ class TestDistributedContextParallelSelfAttn:
dropout_prob
=
0.0
is_training
=
True
dp_size
,
cp_size
,
tp_size
=
mesh_shape
qkv_format
=
qkv_layout
.
get_qkv_format
()
batch
,
seqlen
,
num_head
,
hidden
=
data_shape
...
...
@@ -382,7 +380,7 @@ class TestDistributedContextParallelSelfAttn:
if
qkv_layout
.
is_thd
()
and
not
load_balanced
:
pytest
.
skip
(
"THD + ring doesn't support unbalanced context parallelism."
)
return
self
.
impl_test_context_parallel_attn
(
self
.
impl_test_context_parallel_attn
(
device_count
,
mesh_shape
,
mesh_axes
,
...
...
@@ -396,6 +394,7 @@ class TestDistributedContextParallelSelfAttn:
CPStrategy
.
RING
,
)
del
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
return
class
TestReorderCausalLoadBalancing
:
...
...
tests/jax/test_distributed_layernorm.py
View file @
a207db1d
...
...
@@ -13,11 +13,30 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from
distributed_test_base
import
generate_configs
,
generate_collectives_count
from
distributed_test_base
import
compare_ops
from
utils
import
pytest_parametrize_wrapper
from
transformer_engine.jax
import
fp8_autocast
from
transformer_engine.common
import
recipe
from
transformer_engine.jax.layernorm
import
layernorm
from
transformer_engine.jax.quantize
import
QuantizerFactory
,
ScalingMode
,
is_fp8_available
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float32
]
NORM_INPUT_SHAPES
=
{
"L0"
:
[[
64
,
64
]],
"L2"
:
[[
64
,
64
]],
}
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
SUPPORTED_RECIPES
=
[]
if
is_fp8_supported
:
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
DelayedScaling
(),
id
=
"DelayedScaling"
))
if
is_mxfp8_supported
:
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
class
TestDistributedLayernorm
:
...
...
@@ -41,25 +60,32 @@ class TestDistributedLayernorm:
return
(
x
,
gamma
,
beta
),
(
x_pspec
,
g_pspec
,
b_pspec
)
def
generate_collectives_count_ref
(
self
,
mesh_resource
,
ln_type
,
shape
,
dtype
):
def
generate_collectives_count_ref
(
self
,
mesh_resource
,
ln_type
,
shape
,
dtype
,
mesh_axes
,
fp8_recipe
):
jax_dtype
=
jax
.
dtypes
.
canonicalize_dtype
(
dtype
)
is_dp_enabled
=
mesh_resource
.
dp_resource
is
not
None
assert
ln_type
in
[
"layernorm"
,
"rmsnorm"
]
all_reduce_loss_bytes
=
4
# 1 * FP32
# for loss, dgamma and dbeta
weight_count
=
2
if
ln_type
==
"layernorm"
else
1
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count
=
2
if
(
ln_type
==
"layernorm"
and
"dp"
in
mesh_axes
)
else
1
allreduce_total_bytes
=
(
all_reduce_loss_bytes
+
weight_count
*
shape
[
-
1
]
*
jax_dtype
.
itemsize
)
other_bytes
=
0
if
fp8_recipe
==
recipe
.
MXFP8BlockScaling
()
and
"dp"
in
mesh_axes
:
other_bytes
=
384
# required for small scale shapes that require padding
return
generate_collectives_count
(
allreduce
=
allreduce_total_bytes
*
int
(
is_dp_enabled
),
allgather
=
0
,
other
=
0
allreduce
=
allreduce_total_bytes
*
int
(
is_dp_enabled
),
allgather
=
0
,
other
=
other_bytes
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[[
32
,
128
,
1024
],
[
32
,
1024
]])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"zero_centered_gamma"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"data_shape"
,
NORM_INPUT_SHAPES
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"zero_centered_gamma"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_layernorm
(
self
,
device_count
,
...
...
@@ -70,12 +96,19 @@ class TestDistributedLayernorm:
dtype
,
zero_centered_gamma
,
shard_weights
,
fp8_recipe
,
):
epsilon
=
1e-6
ln_type
=
"layernorm"
q_dtype
=
jnp
.
float8_e4m3fn
def
target_func
(
x
,
gamma
,
beta
):
return
jnp
.
mean
(
layernorm
(
x
,
gamma
,
beta
,
ln_type
,
zero_centered_gamma
,
epsilon
))
quantizer
=
QuantizerFactory
.
create_set
().
x
return
jnp
.
mean
(
layernorm
(
x
,
gamma
,
beta
,
ln_type
,
zero_centered_gamma
,
epsilon
,
quantizer
=
quantizer
)
)
def
ref_func
(
x
,
gamma
,
beta
):
x_
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
...
...
@@ -92,11 +125,11 @@ class TestDistributedLayernorm:
data_shape
,
mesh_resource
,
dtype
,
shard_weights
)
collective_count_ref
=
self
.
generate_collectives_count_ref
(
mesh_resource
,
ln_type
,
data_shape
,
dtype
mesh_resource
,
ln_type
,
data_shape
,
dtype
,
mesh_axes
,
fp8_recipe
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
mesh_resource
=
mesh_resource
):
with
mesh
,
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
x_
=
jax
.
device_put
(
x
,
NamedSharding
(
mesh
,
x_pspec
))
gamma_
=
jax
.
device_put
(
gamma
,
NamedSharding
(
mesh
,
g_pspec
))
beta_
=
jax
.
device_put
(
beta
,
NamedSharding
(
mesh
,
b_pspec
))
...
...
@@ -109,8 +142,8 @@ class TestDistributedLayernorm:
[
x_
,
gamma_
,
beta_
],
collective_count_ref
,
grad_args
=
(
0
,
1
,
2
),
metric_fwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
metric_fwd_dtype
=
q_
dtype
,
metric_bwd_dtype
=
q_
dtype
,
in_shardings
=
(
x_pspec
,
g_pspec
,
b_pspec
),
out_shardings
=
(
None
,
(
x_pspec
,
g_pspec
,
b_pspec
)),
)
...
...
@@ -131,17 +164,28 @@ class TestDistributedLayernorm:
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[[
32
,
128
,
1024
],
[
32
,
1024
]])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"data_shape"
,
NORM_INPUT_SHAPES
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_rmsnorm
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
dtype
,
shard_weights
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
dtype
,
shard_weights
,
fp8_recipe
,
):
epsilon
=
1e-6
ln_type
=
"rmsnorm"
q_dtype
=
jnp
.
float8_e4m3fn
def
target_func
(
x
,
gamma
):
return
jnp
.
mean
(
layernorm
(
x
,
gamma
,
None
,
ln_type
,
False
,
epsilon
))
quantizer
=
QuantizerFactory
.
create_set
().
x
return
jnp
.
mean
(
layernorm
(
x
,
gamma
,
None
,
ln_type
,
False
,
epsilon
,
quantizer
=
quantizer
))
def
ref_func
(
x
,
gamma
):
x
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
...
...
@@ -154,11 +198,11 @@ class TestDistributedLayernorm:
data_shape
,
mesh_resource
,
dtype
,
shard_weights
)
collective_count_ref
=
self
.
generate_collectives_count_ref
(
mesh_resource
,
ln_type
,
data_shape
,
dtype
mesh_resource
,
ln_type
,
data_shape
,
dtype
,
mesh_axes
,
fp8_recipe
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
mesh_resource
=
mesh_resource
):
with
mesh
,
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
x_
=
jax
.
device_put
(
x
,
NamedSharding
(
mesh
,
x_pspec
))
gamma_
=
jax
.
device_put
(
gamma
,
NamedSharding
(
mesh
,
g_pspec
))
...
...
@@ -170,8 +214,8 @@ class TestDistributedLayernorm:
[
x_
,
gamma_
],
collective_count_ref
,
grad_args
=
(
0
,
1
),
metric_fwd_dtype
=
dtype
,
metric_bwd_dtype
=
dtype
,
metric_fwd_dtype
=
q_
dtype
,
metric_bwd_dtype
=
q_
dtype
,
in_shardings
=
(
x_pspec
,
g_pspec
),
out_shardings
=
(
None
,
(
x_pspec
,
g_pspec
)),
)
...
...
tests/jax/test_distributed_layernorm_mlp.py
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from
typing
import
Callable
,
Sequence
,
Union
,
Optional
import
pytest
from
typing
import
Callable
,
List
,
Sequence
,
Union
import
jax
import
jax.numpy
as
jnp
import
numpy
as
np
from
jax.sharding
import
Mesh
,
NamedSharding
,
PartitionSpec
from
utils
import
(
assert_allclose
,
assert_tree_like_allclose
,
is_devices_enough
,
pytest_parametrize_wrapper
,
)
from
transformer_engine.
jax.fp8
import
FP8MetaPackage
,
FP8Hel
pe
r
from
transformer_engine.jax.
fp8
import
is_fp8_available
from
transformer_engine.
common
import
reci
pe
from
transformer_engine.jax.
quantize
import
is_fp8_available
,
ScalingMode
from
transformer_engine.jax
import
fp8_autocast
from
transformer_engine.jax.flax
import
LayerNormMLP
from
transformer_engine.jax.layernorm_mlp
import
fused_
layernorm_
fp8_
mlp
from
transformer_engine.jax.layernorm_mlp
import
layernorm_mlp
from
transformer_engine.jax.sharding
import
(
HIDDEN_AXES
,
HIDDEN_TP_AXES
,
...
...
@@ -26,17 +32,25 @@ from transformer_engine.jax.sharding import (
W_JOINED_AXES
,
)
from
transformer_engine.jax.sharding
import
MeshResource
from
transformer_engine.jax.quantize
import
QuantizerFactory
from
utils
import
assert_allclose
,
assert_tree_like_allclose
,
is_devices_enough
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
SUPPORTED_RECIPES
=
[]
if
is_fp8_supported
:
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
DelayedScaling
(),
id
=
"DelayedScaling"
))
if
is_mxfp8_supported
:
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float16
]
INPUT_SHAPE
=
[[
64
,
128
,
32
]]
# [batch, seqlen, hidden_in]
INPUT_SHAPE
=
[[
2
,
64
,
64
]]
# [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_TP_AXES
,
HIDDEN_AXES
)
DOT_1_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_AXES
)
DOT_2_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_TP_AXES
)
INTERMEDIATE
=
1
6
INTERMEDIATE
=
6
4
# Only test with FSDP and TP as DP is not used
...
...
@@ -66,13 +80,13 @@ class TestDistributedLayernormMLP:
x
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
batch
,
seqlen
,
hidden_in
),
dtype
)
gamma
=
jax
.
random
.
normal
(
subkeys
[
5
],
(
hidden_in
,),
dtype
=
dtype
)
k1
=
jax
.
random
.
normal
(
subkeys
[
1
],
(
hidden_in
,
len
(
activation_type
)
,
INTERMEDIATE
),
dtype
subkeys
[
1
],
(
hidden_in
,
len
(
activation_type
)
*
INTERMEDIATE
),
dtype
)
/
jnp
.
sqrt
(
hidden_in
)
k2
=
jax
.
random
.
normal
(
subkeys
[
2
],
(
INTERMEDIATE
,
hidden_out
),
dtype
)
/
jnp
.
sqrt
(
INTERMEDIATE
)
if
use_bias
:
b1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
,
INTERMEDIATE
),
dtype
)
b1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
*
INTERMEDIATE
),
dtype
)
b2
=
jax
.
random
.
normal
(
subkeys
[
4
],
(
hidden_out
,),
dtype
)
else
:
b1
=
None
...
...
@@ -86,35 +100,13 @@ class TestDistributedLayernormMLP:
ln_scale
:
jnp
.
ndarray
,
kernel_1
:
jnp
.
ndarray
,
kernel_2
:
jnp
.
ndarray
,
bias_1
:
jnp
.
ndarray
,
bias_2
:
jnp
.
ndarray
,
amax_list_1
:
List
[
jnp
.
ndarray
],
amax_list_2
:
List
[
jnp
.
ndarray
],
scale_list_1
:
List
[
jnp
.
ndarray
],
scale_list_2
:
List
[
jnp
.
ndarray
],
bias_1
:
Optional
[
jnp
.
ndarray
],
bias_2
:
Optional
[
jnp
.
ndarray
],
layernorm_type
:
str
=
"rmsnorm"
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
use_bias
:
bool
=
True
,
multi_gpus
:
bool
=
False
,
)
->
jnp
.
ndarray
:
fp8_meta_pkg1
=
FP8MetaPackage
(
amax_list_1
[
0
],
scale_list_1
[
0
],
amax_list_1
[
1
],
scale_list_1
[
1
],
amax_list_1
[
2
],
scale_list_1
[
2
],
)
fp8_meta_pkg2
=
FP8MetaPackage
(
amax_list_2
[
0
],
scale_list_2
[
0
],
amax_list_2
[
1
],
scale_list_2
[
1
],
amax_list_2
[
2
],
scale_list_2
[
2
],
)
if
multi_gpus
:
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
dot_1_input_axes
=
DOT_1_INPUT_AXES
...
...
@@ -124,83 +116,64 @@ class TestDistributedLayernormMLP:
dot_1_input_axes
=
None
dot_2_input_axes
=
None
quantizer_sets
=
QuantizerFactory
.
create_set
(
n_quantizer_sets
=
2
)
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return
jnp
.
mean
(
fused_
layernorm_
fp8_
mlp
(
layernorm_mlp
(
x
,
ln_scale
,
None
,
[
kernel_1
,
kernel_2
],
[
bias_1
,
bias_2
],
[
fp8_meta_pkg1
,
fp8_meta_pkg2
],
layernorm_type
,
layer
norm_input_axes
=
layernorm_input_axes
,
norm_input_axes
=
layernorm_input_axes
,
dot_1_input_axes
=
dot_1_input_axes
,
dot_2_input_axes
=
dot_2_input_axes
,
activation_type
=
activation_type
,
use_bias
=
use_bia
s
,
quantizer_sets
=
quantizer_set
s
,
)
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest
.
mark
.
parametrize
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest
.
mark
.
parametrize
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_layernorm_fp8_mlp_primitive
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
):
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
=
mesh_config
layernorm_type
=
"rmsnorm"
fp8_amax_list_1
=
[
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
]
fp8_amax_list_2
=
[
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
jnp
.
zeros
((
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
),
]
fp8_scale_list_1
=
[
jnp
.
ones
((
1
,),
jnp
.
float32
),
jnp
.
ones
((
1
,),
jnp
.
float32
),
jnp
.
ones
((
1
,),
jnp
.
float32
),
]
fp8_scale_list_2
=
[
jnp
.
ones
((
1
,),
jnp
.
float32
),
jnp
.
ones
((
1
,),
jnp
.
float32
),
jnp
.
ones
((
1
,),
jnp
.
float32
),
]
inputs
=
[
x
,
gamma
,
k1
,
k2
,
b1
,
b2
]
=
self
.
generate_inputs
(
input_shape
,
activation_type
,
use_bias
,
dtype
)
inputs
=
[
*
inputs
,
fp8_amax_list_1
,
fp8_amax_list_2
,
fp8_scale_list_1
,
fp8_scale_list_2
]
static_inputs
=
[
layernorm_type
,
activation_type
,
use_bias
]
static_inputs
=
[
layernorm_type
,
activation_type
]
value_and_grad_func
=
jax
.
value_and_grad
(
self
.
layernorm_fp8_mlp_prim_func
,
argnums
=
range
(
len
(
inputs
))
)
# Single GPU
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
):
single_jitter
=
jax
.
jit
(
value_and_grad_func
,
static_argnums
=
range
(
len
(
inputs
),
len
(
static_inputs
)
+
len
(
inputs
))
value_and_grad_func
,
static_argnums
=
range
(
len
(
inputs
),
len
(
static_inputs
)
+
len
(
inputs
)),
)
with
fp8_autocast
(
enabled
=
True
):
single_fwd
,
single_grads
=
single_jitter
(
*
inputs
,
*
static_inputs
)
# Multi GPUs
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
enabled
=
True
,
mesh_resource
=
mesh_resource
):
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
None
,
"tp"
))
with
mesh
,
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
"tp"
))
k2_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
,
"fsdp"
))
k1_
=
jax
.
device_put
(
k1
,
k1_sharding
)
k2_
=
jax
.
device_put
(
k2
,
k2_sharding
)
if
use_bias
:
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
,
"tp"
))
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
))
b1_
=
jax
.
device_put
(
b1
,
b1_sharding
)
else
:
b1_sharding
=
b1_
=
None
...
...
@@ -208,7 +181,7 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2
, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
# b2
in_shardings
=
(
None
,
None
,
...
...
@@ -216,14 +189,10 @@ class TestDistributedLayernormMLP:
k2_sharding
,
b1_sharding
,
None
,
None
,
None
,
None
,
None
,
)
out_shardings
=
(
None
,
(
None
,
None
,
k1_sharding
,
k2_sharding
,
b1_sharding
,
None
,
None
,
None
,
None
,
None
),
(
None
,
None
,
k1_sharding
,
k2_sharding
,
b1_sharding
,
None
),
)
multi_jitter
=
jax
.
jit
(
...
...
@@ -245,15 +214,42 @@ class TestDistributedLayernormMLP:
m_grad
,
s_grad
,
dtype
=
dtype
,
err_msg
=
f
"multi_grads[
{
i
}
] is not close"
)
else
:
is_gated
=
len
(
activation_type
)
>
1
rtol
=
None
atol
=
None
if
is_gated
:
if
dtype
==
jnp
.
bfloat16
:
if
i
==
2
:
rtol
=
800
atol
=
9e-2
if
i
==
4
:
atol
=
300
rtol
=
1e-1
if
dtype
==
jnp
.
float16
:
if
i
==
1
:
# gamma
rtol
=
200
atol
=
1e-2
if
i
==
2
:
rtol
=
2000
atol
=
7e-2
if
i
==
4
and
fp8_recipe
==
recipe
.
MXFP8BlockScaling
():
# bias_1
# Accumulating dbias across a large tensor introduces a larger difference
rtol
=
200
atol
=
4e-2
if
i
==
4
and
fp8_recipe
==
recipe
.
DelayedScaling
():
rtol
=
2200
atol
=
9e-2
assert_allclose
(
multi_grads
[
i
],
single_grads
[
i
],
dtype
=
dtype
,
rtol
=
rtol
,
atol
=
atol
,
err_msg
=
f
"multi_grads[
{
i
}
] is not close"
,
)
def
_test_layernorm_mlp
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
,
fp8_recipe
=
None
):
batch
,
seqlen
,
hidden_in
=
input_shape
layernorm_type
=
"rmsnorm"
...
...
@@ -265,7 +261,7 @@ class TestDistributedLayernormMLP:
init_rngs
=
{
"params"
:
subkeys
[
1
]}
# Single GPUs
with
fp8_autocast
(
enabled
=
use_fp8
):
with
fp8_autocast
(
enabled
=
use_fp8
,
fp8_recipe
=
fp8_recipe
):
ln_mlp_single
=
LayerNormMLP
(
layernorm_type
=
layernorm_type
,
transpose_batch_sequence
=
False
,
# input: [batch, seqlen, hidden]
...
...
@@ -282,7 +278,9 @@ class TestDistributedLayernormMLP:
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
=
mesh_config
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
enabled
=
use_fp8
,
mesh_resource
=
mesh_resource
):
with
mesh
,
fp8_autocast
(
enabled
=
use_fp8
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
ln_mlp_sharded
=
LayerNormMLP
(
layernorm_type
=
layernorm_type
,
transpose_batch_sequence
=
False
,
...
...
@@ -310,25 +308,30 @@ class TestDistributedLayernormMLP:
assert_allclose
(
ln_out_sharded
,
ln_out_single
,
dtype
=
dtype
)
assert_allclose
(
mlp_out_sharded
,
mlp_out_single
,
dtype
=
dtype
)
@
pytest
.
mark
.
parametrize
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest
.
mark
.
parametrize
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest
.
mark
.
parametrize
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)
,
(
"gelu"
,
"gelu"
)
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
_
parametrize
_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest
_
parametrize
_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest
_
parametrize
_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest
_
parametrize
_wrapper
(
"dtype"
,
DTYPES
)
@
pytest
_
parametrize
_wrapper
(
"use_bias"
,
[
True
,
False
])
def
test_layernorm_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
):
self
.
_test_layernorm_mlp
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
False
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest
.
mark
.
parametrize
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
),
(
"gelu"
,
"gelu"
)])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_layernorm_fp8_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
):
self
.
_test_layernorm_mlp
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
True
)
# TODO: debug
# @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
# @pytest_parametrize_wrapper(
# "activation_type", [("gelu",), ("gelu", "linear")]
# )
# @pytest_parametrize_wrapper("use_bias", [True, False])
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
# @pytest_parametrize_wrapper("dtype", DTYPES)
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
# def test_layernorm_fp8_mlp_layer(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
# ):
# self._test_layernorm_mlp(
# mesh_config, activation_type, use_bias, input_shape, dtype,
# use_fp8=True, fp8_recipe=fp8_recipe
# )
tests/jax/test_distributed_softmax.py
View file @
a207db1d
...
...
@@ -3,8 +3,8 @@
# See LICENSE for license information.
import
warnings
import
pytest
from
functools
import
partial
import
pytest
import
jax
import
jax.numpy
as
jnp
...
...
tests/jax/test_helper.py
View file @
a207db1d
...
...
@@ -13,13 +13,13 @@ from utils import assert_allclose
from
transformer_engine.common.recipe
import
DelayedScaling
from
transformer_engine.common.recipe
import
Format
as
FP8Format
from
transformer_engine.jax
import
fp8_autocast
,
get_delayed_scaling
from
transformer_engine.jax.
fp8
import
FP8Helper
,
is_fp8_available
,
AmaxComputeAlgo
from
transformer_engine.jax.
quantize
import
QuantizeConfig
,
is_fp8_available
,
AmaxComputeAlgo
from
transformer_engine.jax.sharding
import
MeshResource
,
global_mesh_resource
is_fp8_supported
,
reason
=
is_fp8_available
()
class
Test
FP8Helper
(
unittest
.
TestCase
):
class
Test
QuantizeConfig
(
unittest
.
TestCase
):
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_initialize
(
self
):
...
...
@@ -27,30 +27,30 @@ class TestFP8Helper(unittest.TestCase):
fp8_format
=
FP8Format
.
E4M3
amax_history_len
=
10
FP8Helper
.
initialize
(
QuantizeConfig
.
initialize
(
margin
=
margin
,
fp8_format
=
fp8_format
,
amax_history_len
=
amax_history_len
)
self
.
assertEqual
(
FP8Helper
.
MARGIN
,
QuantizeConfig
.
MARGIN
,
margin
,
f
"
FP8Helper
.MARGIN initialization failed, should be
{
margin
}
"
f
" but got
{
FP8Helper
.
MARGIN
}
."
,
f
"
QuantizeConfig
.MARGIN initialization failed, should be
{
margin
}
"
f
" but got
{
QuantizeConfig
.
MARGIN
}
."
,
)
self
.
assertEqual
(
FP8Helper
.
FP8_FORMAT
,
QuantizeConfig
.
FP8_FORMAT
,
fp8_format
,
f
"
FP8Helper
.FP8_FORMAT initialization failed, should be
{
fp8_format
}
"
f
" but got
{
FP8Helper
.
FP8_FORMAT
}
."
,
f
"
QuantizeConfig
.FP8_FORMAT initialization failed, should be
{
fp8_format
}
"
f
" but got
{
QuantizeConfig
.
FP8_FORMAT
}
."
,
)
self
.
assertEqual
(
FP8Helper
.
AMAX_HISTORY_LEN
,
QuantizeConfig
.
AMAX_HISTORY_LEN
,
amax_history_len
,
f
"
FP8Helper
.AMAX_HISTORY_LEN initialization failed, should be
{
amax_history_len
}
"
f
" but got
{
FP8Helper
.
AMAX_HISTORY_LEN
}
."
,
f
"
QuantizeConfig
.AMAX_HISTORY_LEN initialization failed, should be
{
amax_history_len
}
"
f
" but got
{
QuantizeConfig
.
AMAX_HISTORY_LEN
}
."
,
)
FP8Helper
.
finalize
()
QuantizeConfig
.
finalize
()
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_update_collections
(
self
):
...
...
@@ -61,12 +61,12 @@ class TestFP8Helper(unittest.TestCase):
"test1"
:
original_val
,
"test2"
:
original_val
,
}
updated_state
=
FP8Helper
.
update_collections
({
"test1"
:
updated_val
},
original_state
)
updated_state
=
QuantizeConfig
.
update_collections
({
"test1"
:
updated_val
},
original_state
)
self
.
assertEqual
(
updated_state
[
"test1"
],
updated_val
)
self
.
assertEqual
(
updated_state
[
"test2"
],
original_val
)
original_state
=
flax
.
core
.
frozen_dict
.
FrozenDict
(
original_state
)
updated_state
=
FP8Helper
.
update_collections
({
"test1"
:
updated_val
},
original_state
)
updated_state
=
QuantizeConfig
.
update_collections
({
"test1"
:
updated_val
},
original_state
)
self
.
assertEqual
(
updated_state
[
"test1"
],
updated_val
)
self
.
assertEqual
(
updated_state
[
"test2"
],
original_val
)
...
...
@@ -74,7 +74,7 @@ class TestFP8Helper(unittest.TestCase):
class
TestFP8Functions
(
unittest
.
TestCase
):
def
_check_defult_state
(
self
):
self
.
assertFalse
(
FP8Helper
.
is_fp8_enabled
())
self
.
assertFalse
(
QuantizeConfig
.
is_fp8_enabled
())
def
_compare_delay_scaling
(
self
,
ref
,
test
):
self
.
assertTrue
(
ref
.
margin
==
test
.
margin
)
...
...
@@ -84,32 +84,32 @@ class TestFP8Functions(unittest.TestCase):
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_fp8_autocast
(
self
):
FP8Helper
.
finalize
()
# Ensure the testing not affect by previous tests.
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_defult_state
()
with
fp8_autocast
(
enabled
=
False
,
fp8_recipe
=
DelayedScaling
()):
self
.
assertFalse
(
FP8Helper
.
is_fp8_enabled
())
self
.
assertFalse
(
QuantizeConfig
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
DelayedScaling
())
self
.
_check_defult_state
()
ds
=
DelayedScaling
(
margin
=
5.0
,
fp8_format
=
FP8Format
.
E4M3
,
amax_history_len
=
1
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
):
self
.
assertTrue
(
FP8Helper
.
is_fp8_enabled
())
self
.
assertTrue
(
QuantizeConfig
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
_check_defult_state
()
ds
=
DelayedScaling
(
margin
=
3.0
,
fp8_format
=
FP8Format
.
HYBRID
,
amax_history_len
=
1
)
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
):
self
.
assertTrue
(
FP8Helper
.
is_fp8_enabled
())
self
.
assertTrue
(
QuantizeConfig
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
_check_defult_state
()
@
unittest
.
skipIf
(
not
is_fp8_supported
,
reason
=
reason
)
def
test_fp8_autocast_with_sharding_resource
(
self
):
FP8Helper
.
finalize
()
# Ensure the testing not affect by previous tests.
QuantizeConfig
.
finalize
()
# Ensure the testing not affect by previous tests.
self
.
_check_defult_state
()
ds
=
DelayedScaling
(
margin
=
5.0
,
fp8_format
=
FP8Format
.
E4M3
,
amax_history_len
=
1
)
...
...
@@ -126,7 +126,7 @@ class TestFP8Functions(unittest.TestCase):
with
jax
.
sharding
.
Mesh
(
devices
,
(
"dp"
,
"tp"
)):
for
sr
in
mesh_s
:
with
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
ds
,
mesh_resource
=
sr
):
self
.
assertTrue
(
FP8Helper
.
is_fp8_enabled
())
self
.
assertTrue
(
QuantizeConfig
.
is_fp8_enabled
())
self
.
_compare_delay_scaling
(
get_delayed_scaling
(),
ds
)
self
.
assertEqual
(
sr
,
global_mesh_resource
())
...
...
tests/jax/test_layer.py
View file @
a207db1d
...
...
@@ -20,11 +20,14 @@ from utils import (
from
utils
import
DecoderLayer
as
RefDecoderLayer
from
utils
import
EncoderLayer
as
RefEncoderLayer
from
transformer_engine.common
.recipe
import
Format
from
transformer_engine.common
import
recipe
from
transformer_engine.jax.flax
import
TransformerLayer
,
TransformerLayerType
from
transformer_engine.jax.fp8
import
FP8Helper
,
is_fp8_available
is_fp8_supported
,
reason
=
is_fp8_available
()
from
transformer_engine.jax.quantize
import
(
QuantizeConfig
,
ScalingMode
,
is_fp8_available
,
update_collections
,
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"function"
)
...
...
@@ -35,12 +38,21 @@ def enable_fused_attn():
del
os
.
environ
[
"NVTE_FUSED_ATTN"
]
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_MXFP8_1D_SCALING
)
QUANTIZE_RECIPES
=
[]
""" Find supported scaling modes"""
if
is_fp8_supported
:
QUANTIZE_RECIPES
.
append
(
pytest
.
param
(
recipe
.
DelayedScaling
(),
id
=
"DelayedScaling"
))
if
is_mxfp8_supported
:
QUANTIZE_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
DATA_SHAPE
=
[
# (batch, seqlen, emb_dim)
pytest
.
param
((
32
,
128
,
1024
),
id
=
"32-128-1024"
),
pytest
.
param
((
32
,
512
,
1024
),
id
=
"32-512-1024"
),
]
DTYPE
=
[
jnp
.
float32
,
jnp
.
bfloat16
]
FP8_FORMATS
=
[
Format
.
E4M3
,
Format
.
HYBRID
]
DTYPE
=
[
jnp
.
bfloat16
]
_KEY_OF_RESIDUAL_POST_LAYERNORM
=
"apply_residual_connection_post_layernorm"
_KEY_OF_OUTPUT_LAYERNORM
=
"output_layernorm"
...
...
@@ -80,27 +92,37 @@ BASE_ATTRS = {
}
ATTRS
=
[
# attrs0
{},
# attrs1
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
},
# attrs2
{
_KEY_OF_ZERO_CENTERED_GAMMA
:
True
,
_KEY_OF_LAYERNORM_EPS
:
1e-2
,
},
# attrs3
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
},
# attrs4
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
},
# attrs5
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_RESIDUAL_POST_LAYERNORM
:
True
,
_KEY_OF_OUTPUT_LAYERNORM
:
True
,
},
# attrs6
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_DROP_PATH
:
0.1
},
# attrs7
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_FUSE_QKV_PARAMS
:
False
},
# attrs8
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
},
# attrs9
{
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
...
@@ -109,12 +131,14 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
_KEY_OF_USE_BIAS
:
True
,
},
# attrs10
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,
"linear"
),
},
# attrs11
{
_KEY_OF_NUM_HEADS
:
8
,
_KEY_OF_NUM_GQA_GROUPS
:
4
,
...
...
@@ -123,33 +147,7 @@ ATTRS = [
_KEY_OF_MLP_ACTIVATIONS
:
(
"gelu"
,),
_KEY_OF_USE_BIAS
:
True
,
},
{
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
},
{
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_HIDDEN_DROPOUT
:
0.8
,
_KEY_OF_INTERMEDIATE_DROPOUT
:
0.5
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
_KEY_OF_USE_BIAS
:
True
,
},
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,
"linear"
)),
},
{
_KEY_OF_NUM_HEADS
:
8
,
_KEY_OF_NUM_GQA_GROUPS
:
4
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_SCALE_ATTN_LOGITS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
_KEY_OF_MLP_ACTIVATIONS
:
((
"silu"
,)),
_KEY_OF_USE_BIAS
:
True
,
},
# attrs12
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
...
@@ -158,12 +156,14 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs13
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_ENABLE_ROPE
:
True
,
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_USE_BIAS
:
True
,
},
# attrs14
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
...
...
@@ -173,6 +173,7 @@ ATTRS = [
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs15
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
...
@@ -180,26 +181,32 @@ ATTRS = [
_KEY_OF_ROPE_GROUP_METHOD
:
"alternate"
,
_KEY_OF_USE_BIAS
:
True
,
},
# attrs16
{
_KEY_OF_HIDDEN_DROPOUT
:
0.3
,
_KEY_OF_HIDDEN_DROPOUT_DIMS
:
(
0
,),
_KEY_OF_INTERMEDIATE_DROPOUT
:
0.5
,
_KEY_OF_INTERMEDIATE_DROPOUT_DIMS
:
(
1
,),
},
# attrs17
{
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_USE_BIAS
:
True
,
},
# attrs18
{
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_BIAS_TYPE
:
"no_bias"
,
},
# attrs19
{
_KEY_OF_ATTENTION_DROPOUT
:
0.3
,
},
# attrs20
{
_KEY_OF_MLP_ACTIVATIONS
:
((
"relu"
,
"relu"
)),
},
# attrs21
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
...
...
@@ -207,6 +214,7 @@ ATTRS = [
_KEY_OF_WINDOW_SIZE
:
(
64
,
0
),
# Left size must < DATA_SHAPE seqlen
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs22
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
...
...
@@ -296,20 +304,24 @@ class BaseRunner:
ref_params
,
test_params
=
self
.
_sync_params
(
ref_params
,
test_params
)
if
FP8Helper
.
is_fp8_enabled
():
if
QuantizeConfig
.
is_fp8_enabled
():
for
_
in
range
(
4
):
_
,
tmp_grad
=
jax
.
value_and_grad
(
self
.
_loss_fn
,
argnums
=
(
3
,),
has_aux
=
False
)(
_
,
updated_state
=
jax
.
value_and_grad
(
self
.
_loss_fn
,
argnums
=
(
3
,),
has_aux
=
False
)(
inputs
,
test_masks
,
test_params
,
test_others
,
test_layer
,
)
_
,
fp8_meta_grad
=
flax
.
core
.
pop
(
tmp_grad
[
0
],
FP8Helper
.
FP8_COLLECTION_NAME
)
test_others
=
FP8Helper
.
update_collections
(
{
FP8Helper
.
FP8_COLLECTION_NAME
:
fp8_meta_grad
},
test_others
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
:
_
,
updated_quantize_meta
=
flax
.
core
.
pop
(
updated_state
[
0
],
QuantizeConfig
.
COLLECTION_NAME
)
test_others
=
update_collections
(
{
QuantizeConfig
.
COLLECTION_NAME
:
updated_quantize_meta
},
test_others
)
del
tmp_grad
,
fp8_meta_grad
del
updated_quantize_meta
del
updated_state
grad_fn
=
jax
.
value_and_grad
(
self
.
_loss_fn
,
argnums
=
(
0
,
2
),
has_aux
=
False
)
...
...
@@ -436,29 +448,29 @@ class BaseTester:
def
test_forward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype forward"""
FP8Helper
.
finalize
()
# Ensure FP8 disabled.
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
)
def
test_backward
(
self
,
data_shape
,
dtype
,
attrs
):
"""Test normal datatype backward"""
FP8Helper
.
finalize
()
# Ensure FP8 disabled.
QuantizeConfig
.
finalize
()
# Ensure FP8 disabled.
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_
format"
,
FP8_FORMAT
S
)
def
test_forward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
format
):
@
pytest
.
mark
.
parametrize
(
"fp8_
recipe"
,
QUANTIZE_RECIPE
S
)
def
test_forward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
recipe
):
"""Test forward with fp8 enabled"""
FP8Helper
.
initialize
(
fp8_
format
=
fp8_format
)
QuantizeConfig
.
initialize
(
fp8_
recipe
=
fp8_recipe
)
self
.
runner
(
attrs
).
test_forward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
FP8Helper
.
finalize
()
QuantizeConfig
.
finalize
()
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"fp8_
format"
,
FP8_FORMAT
S
)
def
test_backward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
format
):
@
pytest
.
mark
.
parametrize
(
"fp8_
recipe"
,
QUANTIZE_RECIPE
S
)
def
test_backward_with_fp8
(
self
,
data_shape
,
dtype
,
attrs
,
fp8_
recipe
):
"""Test backward with fp8 enabled"""
FP8Helper
.
initialize
(
fp8_
format
=
fp8_format
)
QuantizeConfig
.
initialize
(
fp8_
recipe
=
fp8_recipe
)
self
.
runner
(
attrs
).
test_backward
(
data_shape
,
dtype
,
rtol
=
1e-4
,
atol
=
1e-3
)
FP8Helper
.
finalize
()
QuantizeConfig
.
finalize
()
class
TestEncoderLayer
(
BaseTester
):
...
...
tests/jax/test_praxis_layers.py
deleted
100644 → 0
View file @
fbee8990
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment