OpenDAS / TransformerEngine · Commits

Commit a207db1d, authored Apr 01, 2025 by yuguo
Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine
Parents: fbee8990, 69365f88
Changes: 101 · Showing 20 changed files with 1659 additions and 2477 deletions (+1659, -2477)
docs/api/pytorch.rst                                        +1     -0
examples/jax/encoder/common.py                              +20    -0
examples/jax/encoder/run_test_multiprocessing_encoder.sh    +7     -1
examples/jax/encoder/test_model_parallel_encoder.py         +63    -22
examples/jax/encoder/test_multigpu_encoder.py               +45    -15
examples/jax/encoder/test_multiprocessing_encoder.py        +59    -31
examples/jax/encoder/test_single_gpu_encoder.py             +31    -10
examples/jax/mnist/test_single_gpu_mnist.py                 +41    -10
qa/L0_jax_unittest/test.sh                                  +3     -4
qa/L0_pytorch_unittest/test.sh                              +1     -1
qa/L2_jax_unittest/test.sh                                  +23    -0
tests/jax/distributed_test_base.py                          +1     -1
tests/jax/test_custom_call_compute.py                       +1107  -747
tests/jax/test_distributed_fused_attn.py                    +4     -5
tests/jax/test_distributed_layernorm.py                     +65    -21
tests/jax/test_distributed_layernorm_mlp.py                 +102   -99
tests/jax/test_distributed_softmax.py                       +1     -1
tests/jax/test_helper.py                                    +22    -22
tests/jax/test_layer.py                                     +63    -51
tests/jax/test_praxis_layers.py                             +0     -1436
docs/api/pytorch.rst
```diff
@@ -32,6 +32,7 @@ pyTorch
    :members: forward, set_context_parallel_group, set_tensor_parallel_group
 .. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length)
    :members: reset, allocate_memory, pre_step, get_seqlens_pre_step, convert_paged_to_nonpaged, step
 .. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker()
+   :members: reset, get_states, set_states, add, fork
```
examples/jax/encoder/common.py
```diff
@@ -4,7 +4,9 @@
 """Shared functions for the encoder tests"""
 from functools import lru_cache
 
+import transformer_engine
 from transformer_engine_jax import get_device_compute_capability
+from transformer_engine.common import recipe
 
 
 @lru_cache
@@ -19,3 +21,21 @@ def is_fp8_supported():
     """Return if FP8 has hardware supported"""
     gpu_arch = get_device_compute_capability(0)
     return gpu_arch >= 90
+
+
+@lru_cache
+def is_mxfp8_supported():
+    """Return if MXFP8 has hardware supported"""
+    gpu_arch = get_device_compute_capability(0)
+    return gpu_arch >= 100
+
+
+def get_fp8_recipe_from_name_string(name: str):
+    """Query recipe from a given name string"""
+    match name:
+        case "DelayedScaling":
+            return recipe.DelayedScaling()
+        case "MXFP8BlockScaling":
+            return recipe.MXFP8BlockScaling()
+        case _:
+            raise ValueError(f"Invalid fp8_recipe, got {name}")
```
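The new helper maps a CLI recipe name to a recipe object from `transformer_engine.common.recipe`. A minimal usage sketch (standalone, not part of the diff):

```python
# Usage sketch: resolve a recipe name, then hand the object to
# te.fp8_autocast as the updated examples below do.
from common import get_fp8_recipe_from_name_string  # the helper defined above

fp8_recipe = get_fp8_recipe_from_name_string("MXFP8BlockScaling")
# get_fp8_recipe_from_name_string("NoSuchRecipe") would raise:
#   ValueError: Invalid fp8_recipe, got NoSuchRecipe
```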
examples/jax/encoder/run_test_multiprocessing_encoder.sh
```diff
@@ -12,6 +12,12 @@ wait
 for i in $(seq 0 $(($NUM_GPUS - 1)))
 do
-    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
+done
+wait
+
+for i in $(seq 0 $(($NUM_GPUS - 1)))
+do
+    pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
 done
 wait
```
examples/jax/encoder/test_model_parallel_encoder.py
```diff
@@ -19,10 +19,11 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
+from common import is_bf16_supported, get_fp8_recipe_from_name_string
 
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
-from common import is_bf16_supported
 
 DEVICE_DP_AXIS = "data"
 DEVICE_TP_AXIS = "model"
```
```diff
@@ -217,9 +218,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
 
 
 def get_params_sharding(sharding_rules, abs_var_collect, mesh):
```
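The reworked `check_fp8` stringifies the jaxpr and searches for the FP8 dtype names (`f8_e4m3`, `f8_e5m2`) instead of the old `fp8_` primitive prefix. The inspection idiom in isolation (a standalone sketch, not from the commit):

```python
import jax
import jax.numpy as jnp

def f(x):
    # A round-trip through an FP8 dtype makes "f8_e4m3" show up in the jaxpr text.
    return x.astype(jnp.float8_e4m3fn).astype(jnp.float32)

func_jaxpr = str(jax.make_jaxpr(f)(jnp.ones((4, 4), jnp.float32)))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
```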
```diff
@@ -272,6 +272,19 @@ def train_and_evaluate(args):
     assert (
         args.test_batch_size % num_gpu_dp == 0
     ), f"Test batch size needs to be multiple of {num_gpu_dp}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert (
+            args.batch_size / num_gpu_dp % 32 == 0
+        ), "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size / num_gpu_dp % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
```
```diff
@@ -287,7 +300,9 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]
 
-    with te.fp8_autocast(
-        args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
-    ):
+    with te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
         encoder = Net(num_embed, args.enable_sp)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
```
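Every example now resolves the recipe first and passes it to `fp8_autocast` explicitly instead of relying on the default. Condensed, the resulting call pattern looks like this (a sketch with stand-in values; names taken from the diff):

```python
import transformer_engine.jax as te
from common import get_fp8_recipe_from_name_string

use_fp8, recipe_name = True, "DelayedScaling"   # stand-ins for the parsed CLI args
fp8_recipe = get_fp8_recipe_from_name_string(recipe_name) if use_fp8 else None
with te.fp8_autocast(
    enabled=use_fp8,
    fp8_recipe=fp8_recipe,  # None leaves the library's default recipe in effect
    mesh_resource=te.MeshResource("data", "model", None, None),
):
    ...  # build and run the model under the chosen FP8 recipe
```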
```diff
@@ -371,21 +386,21 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for training (default: 64)",
+        help="input batch size for training (default: 128)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for testing (default: 64)",
+        help="input batch size for testing (default: 128)",
     )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=32,
+        default=64,
         metavar="N",
         help="maximum sequence length (default: 32)",
     )
```
```diff
@@ -416,6 +431,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
     )
```
```diff
@@ -426,7 +447,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
 
     @classmethod
     def setUpClass(cls):
```
```diff
@@ -437,29 +459,48 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        assert actual[0] < 0.50 and actual[1] > 0.76
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.76
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.50 and actual[1] > 0.76
 
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
-    def test_te_bf16_sp(self):
+    def test_te_bf16_with_sp(self):
         """Test Transformer Engine with BF16 + SP"""
         self.args.enable_sp = True
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.45 and actual[1] > 0.79
+        assert actual[0] < 0.50 and actual[1] > 0.76
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8_sp(self):
-        """Test Transformer Engine with FP8 + SP"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8_with_sp(self):
+        """Test Transformer Engine with DelayedScaling FP8 + SP"""
+        self.args.enable_sp = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.50 and actual[1] > 0.76
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8_with_sp(self):
+        """Test Transformer Engine with MXFP8 + SP"""
         self.args.enable_sp = True
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.455 and actual[1] > 0.785
+        assert actual[0] < 0.50 and actual[1] > 0.76
 
 
 if __name__ == "__main__":
```
examples/jax/encoder/test_multigpu_encoder.py
```diff
@@ -19,10 +19,11 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
+from common import is_bf16_supported, get_fp8_recipe_from_name_string
 
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
-from common import is_bf16_supported
 
 DEVICE_DP_AXIS = "data"
 PARAMS_KEY = "params"
```
```diff
@@ -198,9 +199,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
 
 
 def get_params_sharding(sharding_rules, abs_var_collect, mesh):
```
```diff
@@ -243,6 +243,18 @@ def train_and_evaluate(args):
     num_gpu = jax.local_device_count()
     assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
     assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert (
+            args.batch_size / num_gpu % 32 == 0
+        ), "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size / num_gpu % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu,))
     with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
```
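The surrounding mesh setup is plain JAX; for readers unfamiliar with it, this is the same pattern in a self-contained sketch (the axis name matches the file's `DEVICE_DP_AXIS`):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

num_gpu = jax.local_device_count()
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with Mesh(devices=device_mesh, axis_names=("data",)) as mesh:
    # Shard a batch along its leading dimension across the "data" axis.
    batch_sharding = NamedSharding(mesh, PartitionSpec("data"))
```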
```diff
@@ -257,7 +269,9 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]
 
-    with te.fp8_autocast(
-        args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
-    ):
+    with te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
+    ):
         encoder = Net(num_embed)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
```
```diff
@@ -344,16 +358,16 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for training (default: 128)",
+        help="input batch size for training (default: 256)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=128,
+        default=256,
         metavar="N",
-        help="input batch size for testing (default: 128)",
+        help="input batch size for testing (default: 256)",
     )
     parser.add_argument(
         "--max-seq-len",
```
```diff
@@ -389,6 +403,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     return parser.parse_args(args)
```
```diff
@@ -396,7 +416,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
 
     @classmethod
     def setUpClass(cls):
```
```diff
@@ -407,14 +428,23 @@ class TestEncoder(unittest.TestCase):
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.535 and actual[1] > 0.73
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.50 and actual[1] > 0.76
+        assert actual[0] < 0.535 and actual[1] > 0.73
 
 
 if __name__ == "__main__":
```
examples/jax/encoder/test_multiprocessing_encoder.py
```diff
@@ -21,9 +21,15 @@ from flax.training import train_state
 from jax.experimental import mesh_utils
 from jax.sharding import PartitionSpec, NamedSharding
-from common import is_bf16_supported, is_fp8_supported
+from common import (
+    is_bf16_supported,
+    is_fp8_supported,
+    is_mxfp8_supported,
+    get_fp8_recipe_from_name_string,
+)
 
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
```
```diff
@@ -298,9 +304,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
 
 
 def get_params_sharding(sharding_rules, abs_var_collect, mesh):
```
```diff
@@ -359,10 +364,16 @@ def train_and_evaluate(args):
         num_gpu_dp = 1
         num_gpu_tp = 1
 
     assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
     assert (
         args.test_batch_size % num_gpu_dp == 0
     ), f"Test batch size needs to be multiple of {num_gpu_dp}"
+    if args.fp8_recipe == "MXFP8BlockScaling":
+        assert args.batch_size % 32 == 0, "Batch size needs to be multiple of 32 for MXFP8"
+        assert (
+            args.test_batch_size % 32 == 0
+        ), "Test batch size needs to be multiple of 32 for MXFP8"
+
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
 
     device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
     with jax.sharding.Mesh(
```
```diff
@@ -379,7 +390,9 @@ def train_and_evaluate(args):
     label_shape = [args.batch_size]
 
-    with te.fp8_autocast(
-        args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
-    ):
+    with te.fp8_autocast(
+        enabled=args.use_fp8,
+        fp8_recipe=fp8_recipe,
+        mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
+    ):
         encoder = Net(num_embed)
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
```
```diff
@@ -482,23 +495,23 @@ def encoder_parser(args):
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for training (default: 64)",
+        help="input batch size for training (default: 128)",
     )
     parser.add_argument(
         "--test-batch-size",
         type=int,
-        default=64,
+        default=128,
         metavar="N",
-        help="input batch size for testing (default: 64)",
+        help="input batch size for testing (default: 128)",
     )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=32,
+        default=64,
         metavar="N",
-        help="maximum sequence length (default: 32)",
+        help="maximum sequence length (default: 64)",
     )
     parser.add_argument(
         "--epochs",
```
```diff
@@ -527,13 +540,19 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--coordinator-address",
         type=str,
         default="127.0.0.1:1234",
         help=(
             "the IP address of process 0 and a port on which that"
             " process should launch a coordinator service (default:"
             " 127.0.0.1:1234)"
         ),
     )
```
```diff
@@ -554,37 +573,46 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    gpu_has_fp8 = is_fp8_supported()
-    gpu_has_bf16 = is_bf16_supported()
-
-    def exec(self, use_fp8):
+    def exec(self, use_fp8, fp8_recipe):
         """Run 3 epochs for testing"""
         args = encoder_parser([])
 
         num_gpu = self.num_process
         tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1
         dp_size = num_gpu // tp_size
-        batch_size = 64 // dp_size
+        assert args.batch_size % dp_size == 0, f"Batch size needs to be multiple of {dp_size}"
+        batch_size = args.batch_size // dp_size
 
         args.use_fp8 = use_fp8
         args.batch_size = batch_size
         args.test_batch_size = batch_size
         args.num_process = num_gpu
         args.process_id = self.process_id
+        args.fp8_recipe = fp8_recipe
 
         return train_and_evaluate(args)
 
-    @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
-        result = self.exec(False)
-        assert result[0] < 0.45 and result[1] > 0.79
+        result = self.exec(False, None)
+        assert result[0] < 0.505 and result[1] > 0.755
 
-    @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8")
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
-        result = self.exec(True)
-        assert result[0] < 0.455 and result[1] > 0.79
+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
+    )
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
+        result = self.exec(True, "DelayedScaling")
+        assert result[0] < 0.505 and result[1] > 0.755
+
+    @unittest.skipIf(not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8")
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        result = self.exec(True, "MXFP8BlockScaling")
+        assert result[0] < 0.505 and result[1] > 0.754
 
 
 if __name__ == "__main__":
```
examples/jax/encoder/test_single_gpu_encoder.py
```diff
@@ -16,10 +16,11 @@ from datasets import load_dataset
 from flax import linen as nn
 from flax.training import train_state
+from common import is_bf16_supported, get_fp8_recipe_from_name_string
 
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
-from common import is_bf16_supported
 
 PARAMS_KEY = "params"
 DROPOUT_KEY = "dropout"
```
```diff
@@ -59,7 +60,7 @@ class Net(nn.Module):
         return x
 
 
-@partial(jax.jit)
+@jax.jit
 def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""
```
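A small cleanup: `partial(jax.jit)` with no extra options is identical to applying `jax.jit` directly; `partial` only earns its keep when jit options such as `static_argnums` are passed. For example:

```python
import jax
from functools import partial

@jax.jit                                 # plain decorator: no jit options needed
def double(x):
    return x * 2

@partial(jax.jit, static_argnums=(1,))   # partial() is needed only to pass options
def scale(x, factor):
    return x * factor
```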
```diff
@@ -195,9 +196,8 @@ def get_datasets(max_seq_len):
 def check_fp8(state, var_collect, inputs, masks, labels):
     "Check if model includes FP8."
     rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
-    assert "fp8_" in str(
-        jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
-    )
+    func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
 
 
 def train_and_evaluate(args):
```
```diff
@@ -214,7 +214,12 @@ def train_and_evaluate(args):
     mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
     label_shape = [args.batch_size]
 
-    with te.fp8_autocast(enabled=args.use_fp8):
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
+
+    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
         encoder = Net(num_embed)
         # We use nn.Embed, thus inputs need to be in int
         inputs = jnp.zeros(input_shape, dtype=jnp.int32)
```
```diff
@@ -309,6 +314,12 @@ def encoder_parser(args):
         default=False,
         help="Use FP8 for inference and training without recalibration",
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     return parser.parse_args(args)
```
```diff
@@ -316,7 +327,8 @@ def encoder_parser(args):
 class TestEncoder(unittest.TestCase):
     """Encoder unittests"""
 
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
 
     @classmethod
     def setUpClass(cls):
```
```diff
@@ -329,10 +341,19 @@ class TestEncoder(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.45 and actual[1] > 0.79
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.455 and actual[1] > 0.79
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.455 and actual[1] > 0.79
```
examples/jax/mnist/test_single_gpu_mnist.py
```diff
@@ -5,6 +5,8 @@
 import argparse
 import unittest
 from functools import partial
+import sys
+from pathlib import Path
 
 import jax
 import jax.numpy as jnp
```
```diff
@@ -16,6 +18,11 @@ from flax.training import train_state
 import transformer_engine.jax as te
 import transformer_engine.jax.flax as te_flax
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
+
+DIR = str(Path(__file__).resolve().parents[1])
+sys.path.append(str(DIR))
+from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string
 
 IMAGE_H = 28
 IMAGE_W = 28
```
```diff
@@ -37,6 +44,7 @@ class Net(nn.Module):
         else:
             nn_Dense = nn.Dense
+        # dtype is used for param init in TE but computation in Linen.nn
         dtype = jnp.float32 if self.use_te else jnp.bfloat16
 
         x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x)
```
```diff
@@ -50,8 +58,8 @@
         x = nn_Dense(features=128, dtype=dtype)(x)
         x = nn.relu(x)
         x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout)
-        x = nn_Dense(features=16, dtype=dtype)(x)
-        x = nn_Dense(features=10, dtype=dtype)(x)
+        x = nn_Dense(features=32, dtype=dtype)(x)
+        x = nn_Dense(features=32, dtype=dtype)(x)
         assert x.dtype == jnp.bfloat16
         return x
```
```diff
@@ -62,7 +70,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
     def loss_fn(var_collect, disable_dropout=False):
         logits = state.apply_fn(var_collect, images, disable_dropout, rngs=rngs)
-        one_hot = jax.nn.one_hot(labels, 10)
+        one_hot = jax.nn.one_hot(labels, 32)
         loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
         return loss, logits
```
```diff
@@ -153,7 +161,7 @@ def get_datasets():
 def check_fp8(state, var_collect, input_shape, label_shape):
     "Check if model includes FP8."
-    assert "f8_" in str(
+    func_jaxpr = str(
         jax.make_jaxpr(apply_model)(
             state,
             jnp.empty(input_shape, dtype=jnp.bfloat16),
@@ -161,6 +169,7 @@ def check_fp8(state, var_collect, input_shape, label_shape):
             var_collect,
         )
     )
+    assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr
 
 
 def train_and_evaluate(args):
```
```diff
@@ -179,7 +188,12 @@ def train_and_evaluate(args):
     input_shape = [args.batch_size, IMAGE_H, IMAGE_W, IMAGE_C]
     label_shape = [args.batch_size]
 
-    with te.fp8_autocast(enabled=args.use_fp8):
+    if args.use_fp8:
+        fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
+    else:
+        fp8_recipe = None
+
+    with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe):
         cnn = Net(args.use_te)
         var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
         tx = optax.sgd(args.lr, args.momentum)
```
```diff
@@ -276,6 +290,12 @@ def mnist_parser(args):
             "It also enables Transformer Engine implicitly."
         ),
     )
+    parser.add_argument(
+        "--fp8-recipe",
+        action="store_true",
+        default="DelayedScaling",
+        help="Use FP8 recipe (default: DelayedScaling)",
+    )
     parser.add_argument(
         "--use-te", action="store_true", default=False, help="Use Transformer Engine"
     )
```
```diff
@@ -286,7 +306,8 @@ def mnist_parser(args):
 class TestMNIST(unittest.TestCase):
     """MNIST unittests"""
 
-    gpu_has_fp8, reason = te.fp8.is_fp8_available()
+    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
 
     @classmethod
     def setUpClass(cls):
```
```diff
@@ -298,13 +319,14 @@ class TestMNIST(unittest.TestCase):
         """Check If loss and accuracy match target"""
         desired_traing_loss = 0.055
         desired_traing_accuracy = 0.98
-        desired_test_loss = 0.04
+        desired_test_loss = 0.045
         desired_test_accuracy = 0.098
         assert actual[0] < desired_traing_loss
         assert actual[1] > desired_traing_accuracy
         assert actual[2] < desired_test_loss
         assert actual[3] > desired_test_accuracy
 
+    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16(self):
         """Test Transformer Engine with BF16"""
         self.args.use_te = True
```
```diff
@@ -312,10 +334,19 @@ class TestMNIST(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         self.verify(actual)
 
-    @unittest.skipIf(not gpu_has_fp8, reason)
-    def test_te_fp8(self):
-        """Test Transformer Engine with FP8"""
+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_delayed_scaling_fp8(self):
+        """Test Transformer Engine with DelayedScaling FP8"""
         self.args.use_fp8 = True
+        self.args.fp8_recipe = "DelayedScaling"
         actual = train_and_evaluate(self.args)
         self.verify(actual)
+
+    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
+    def test_te_mxfp8(self):
+        """Test Transformer Engine with MXFP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "MXFP8BlockScaling"
+        actual = train_and_evaluate(self.args)
+        self.verify(actual)
```
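The MNIST head widened from 10 to 32 output features, with the one-hot targets padded to match (likely to satisfy the 32-element alignment the MXFP8 path checks elsewhere in this commit). The padding is harmless for a 10-class problem, as a quick check shows:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([3, 9])                  # valid MNIST labels, 0..9
one_hot = jax.nn.one_hot(labels, 32)        # padded to 32 classes
assert one_hot.shape == (2, 32)
assert float(one_hot[:, 10:].sum()) == 0.0  # padded classes get zero target mass
```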
qa/L0_jax_unittest/test.sh
```diff
@@ -20,16 +20,15 @@ pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
 pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 
 : ${TE_PATH:=/opt/transformerengine}
 
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "test_praxis_layers.py"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
 
 # Test without custom calls
-NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
+NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls"
 
 pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
+pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"
 
-pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
 python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
```
qa/L0_pytorch_unittest/test.sh
```diff
@@ -38,7 +38,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
 
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases: $FAILED_CASES"
```
qa/L2_jax_unittest/test.sh (new file, mode 100644)
```bash
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -xe

pip install "nltk>=3.8.2"
pip install pytest==8.2.1

: ${TE_PATH:=/opt/transformerengine}

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py

# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py

pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
```
tests/jax/distributed_test_base.py
```diff
@@ -82,7 +82,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
         'i32[1024]{0}',
         'bf16[1024,1024]{0}'
     """
-    match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
+    match = re.search(r"(i|f|u)(\d+).*\[([0-9,]*)\]", t)
     _, bits_of_type, shape = match.groups()
     bytes_of_type = int(bits_of_type) // 8
     if shape == "":
```
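The widened alternation also matches unsigned-integer buffers in HLO type strings; a quick standalone check of the new pattern:

```python
import re

pattern = r"(i|f|u)(\d+).*\[([0-9,]*)\]"
for t in ("i32[1024]{0}", "bf16[1024,1024]{0}", "u8[256]{0}"):
    _, bits_of_type, shape = re.search(pattern, t).groups()
    print(int(bits_of_type) // 8, shape)  # bytes per element, shape string
```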
tests/jax/test_custom_call_compute.py
```diff
@@ -2,31 +2,40 @@
 #
 # See LICENSE for license information.
-from contextlib import nullcontext
-from typing import Callable, List, Sequence, Union
+import os
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
 from jax import jit, value_and_grad
 from flax import linen as nn
+from functools import reduce
+import operator
 
-from utils import assert_allclose, assert_tree_like_allclose
-from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
-from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
-from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
-from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
-from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu
-from transformer_engine.jax.cpp_extensions.transpose import (
-    _jax_transpose,
-    _jax_cast_transpose,
-    _jax_dbias_cast_transpose,
-)
-from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8
+from utils import (
+    assert_allclose,
+    assert_tree_like_allclose,
+    pytest_parametrize_wrapper,
+)
+from transformer_engine.jax.layernorm import layernorm
+from transformer_engine.jax.layernorm_mlp import layernorm_mlp
+from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
+from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
+from transformer_engine.jax.cpp_extensions.quantization import (
+    _jax_quantize,
+    _jax_quantize_dbias,
+)
 from transformer_engine.jax import cpp_extensions as tex
+from transformer_engine.jax.quantize import (
+    DelayedScaleQuantizer,
+    ScaledTensor,
+    ScalingMode,
+    QuantizerFactory,
+    QuantizeAxis,
+)
+from transformer_engine.jax.quantize import helper
+from transformer_engine.jax.activation import activation
+from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.layernorm_dense import layernorm_dense
+from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
 
 GEMM_CASES = [
     (256, 256, 512),
```
```diff
@@ -36,844 +45,1195 @@ GEMM_CASES = [
     (2048, 1024, 1024),
 ]
-FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
-LN_CASES = [(512, 1024)]
+LN_CASES = [(256, 128), (128, 256)]
 DTYPES = [jnp.bfloat16, jnp.float32]
-is_fp8_supported, reason = is_fp8_available()
-
-
-class TestFP8Dot:
-    @staticmethod
-    def _generate_fp8_meta():
-        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
-        amax_list = [
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-        ]
-        scale_list = [
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-        ]
-        return fp8_dtype_list, amax_list, scale_list
+is_fp8_supported, reason = helper.is_fp8_available()
+is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+supported_scaling_modes = []
+"""Find supported scaling modes"""
+if is_fp8_supported:
+    supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+if is_mxfp8_supported:
+    supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+
+def is_shape_supported_by_mxfp8(input_shape):
+    try:
+        if isinstance(input_shape, type(pytest.param(0))):
+            input_shape = input_shape.values[0]
+        ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
+        return True
+    except:
+        # get_scale_shapes will raise an exception if the shape is not supported
+        return False
+
+
+def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
+    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
+        assert_allclose(a.data, b.data)
+        assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
+    elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
+        assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
+        assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)
+    else:
+        pytest.fail("Unsupported input types")
+
+
+def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
+    if isinstance(a, ScaledTensor1x):
+        if a.layout == "T":
+            b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1)))
+            assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
+        else:
+            assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
+    elif isinstance(a, ScaledTensor2x):
+        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
+        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
+    else:
+        pytest.fail("a must be a ScaledTensor object")
+
+
+ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]
+ALL_ACTIVATION_TYPES = [
+    ("gelu",),
+    ("gelu", "linear"),
+    ("silu",),
+    ("silu", "linear"),
+    ("relu",),
+    ("relu", "linear"),
+    ("quick_gelu",),
+    ("quick_gelu", "linear"),
+    ("squared_relu",),
+    ("squared_relu", "linear"),
+]
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    def test_qdq(self):
-        FP8_E4M3_MAX = (jnp.finfo(jnp.float8_e4m3fn).max).astype(jnp.float32)
-        x = jnp.asarray([[-1, 0.1], [2, 3]], jnp.float32)
-        amax = jnp.max(jnp.abs(x)).reshape(1)
-        scale = jnp.asarray(FP8_E4M3_MAX / amax, jnp.float32).reshape(1)
-        scale_inv = (1 / scale).reshape(1)
+ACTIVATION_TYPES = {
+    "L0": [
+        ("gelu",),
+        ("gelu", "linear"),
+    ],
+    "L2": ALL_ACTIVATION_TYPES,
+}
-        y, _ = quantize(x, q_dtype=jnp.float8_e4m3fn, scale=scale)
-        z = dequantize(y, dq_dtype=jnp.float32, scale_inv=scale_inv)
-        assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
```
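These `"L0"`/`"L2"` dictionaries pair with the `NVTE_JAX_UNITTEST_LEVEL` variable exported by the new `qa/L2_jax_unittest/test.sh`. The resolution presumably happens inside `pytest_parametrize_wrapper` (not shown in this diff); a hypothetical sketch of the idea:

```python
import os

# Hypothetical resolution logic; the real helper lives in the test utils,
# which this diff does not show.
def resolve_params(params):
    if isinstance(params, dict):
        level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", "L0")
        return params[level]  # "L0": fast subset, "L2": full sweep
    return params
```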
```diff
+class TestActivation:
+    def ref_act(self, x, activation_type):
+        return _jax_act_lu(x, activation_type)
+
+    def value_n_grad_ref_func(self, x, activation_type):
+        jitted_reference = jit(
+            value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,))
+        )
+        return jitted_reference(x)
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    def test_forward_bf16(self, m, n, k):
+    def primitive_func(self, inputs, activation_type, quantizer):
+        out = activation(inputs, activation_type=activation_type, quantizer=quantizer)
+        return jnp.mean(out)
+
+    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
+    @pytest_parametrize_wrapper(
+        "activation_type",
+        (
+            ALL_ACTIVATION_TYPES
+            # Test all activation types for this test to ensure all are functional,
+            # then just test a subset for the other tests to verify other functionality
+        ),
+    )
+    def test_act_grad(self, shape, activation_type):
         key = jax.random.PRNGKey(0)
         subkeys = jax.random.split(key, 2)
-        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
-        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
+        x = jax.random.uniform(key, shape, jnp.float32)
+        x = jnp.repeat(x, len(activation_type), axis=-1)
-        primitive_out = type_safe_dot_general(a, b)
-        ref_out = jnp.dot(a, b)
+        value_n_grad_primitive_func = jit(
+            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
+        )
-        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
+        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None)
+        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    def test_forward_fp8_randint(self, m, n, k):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 2)
+        assert_allclose(prim_out, ref_out, dtype=x.dtype)
+        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
-        dtype = jnp.bfloat16
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
+    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
+    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
+        x = random_inputs
+        x = jnp.repeat(x, len(activation_type), axis=-1)
+        self.activation_type = activation_type
-        # TODO(rewang): add float random test
-        min_val, max_val = -8, 8
-        a = jax.random.randint(subkeys[0], (m, k), min_val, max_val).astype(dtype)
-        b = jax.random.randint(subkeys[1], (k, n), min_val, max_val).astype(dtype)
+        value_n_grad_primitive_func = jit(
+            value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
+        )
-        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
-        fp8_meta_pkg = FP8MetaPackage(
-            amax_list[0],
-            scale_list[0],
-            amax_list[1],
-            scale_list[1],
-            amax_list[2],
-            scale_list[2],
-        )
+        quantizer = QuantizerFactory.create(
+            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
+            q_dtype=output_type,
+            q_axis=QuantizeAxis.ROWWISE,
+        )
-        primitive_out = type_safe_dot_general(a, b, fp8_meta_pkg)
-        ref_out = jnp.dot(a, b)
-        ref_out = ref_out.astype(jnp.float32)
-        primitive_out = primitive_out.astype(jnp.float32)
+        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
+        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type)
-        assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
+        assert_allclose(prim_out, ref_out, dtype=output_type)
+        assert_allclose(prim_grad, ref_grad, dtype=output_type)
```
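The new tests repeatedly build `jit(value_and_grad(f, (0,)), static_argnums=(1,))` so the activation-type tuple is a compile-time constant while gradients flow only to the array argument. The idiom in isolation:

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

def loss(x, activation_type):
    y = jax.nn.relu(x) if "relu" in activation_type else jax.nn.gelu(x)
    return jnp.mean(y)

# Differentiate w.r.t. argument 0 only; argument 1 (a tuple of strings) is static.
value_n_grad = jit(value_and_grad(loss, argnums=(0,)), static_argnums=(1,))
val, (grad,) = value_n_grad(jnp.ones(8), ("relu",))
```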
```diff
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    def test_grad_bf16(self, m, n, k):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 2)
-        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
-        b = jax.random.normal(subkeys[1], (k, n), jnp.bfloat16)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
+    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
+    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+    def test_act_forward_with_delayed_scaling_fp8(
+        self, random_inputs, activation_type, output_type, q_axis
+    ):
+        x = random_inputs
+        x = jnp.repeat(x, len(activation_type), axis=-1)
+        self.activation_type = activation_type
-        def primitive_func(x, y):
-            primitive_out = type_safe_dot_general(x, y)
-            return jnp.mean(primitive_out)
+        te_quantizer, jax_quantizer = QuantizerFactory.create(
+            n_quantizers=2,
+            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
+            q_dtype=output_type,
+            q_axis=q_axis,
+        )
-        def ref_func(x, y):
-            return jnp.mean(jnp.dot(x, y))
+        te_output = tex.act_lu(x, activation_type, te_quantizer)
+        jax_output = _jax_act_lu(x, activation_type, jax_quantizer)
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
+        assert_bitwise_scaled_tensors(te_output, jax_output)
-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("shape", [(128, 128)])
+    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
+    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+    def test_act_forward_with_block_scaling_fp8(
+        self, random_inputs, activation_type, output_type, q_axis
+    ):
+        x = random_inputs
+        x = jnp.repeat(x, len(activation_type), axis=-1)
+        self.activation_type = activation_type
-        primitive_out, (primitive_a_grad, primitive_b_grad) = value_n_grad_primitive_func(a, b)
-        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
+        quantizer = QuantizerFactory.create(
+            scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis
+        )
-        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
-        assert_allclose(primitive_a_grad, ref_a_grad, dtype=jnp.bfloat16)
-        assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)
+        output = tex.act_lu(x, activation_type, quantizer)
+        ref_out = self.ref_act(x, activation_type)
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("m,n,k", GEMM_CASES)
-    def test_grad_fp8_dot(self, m, n, k):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 2)
+        assert_dequantized_scaled_tensor(output, ref_out)
-        a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
-        b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
-        _, amax_list, scale_list = TestFP8Dot._generate_fp8_meta()
```
```diff
+NORM_OUTPUT_DTYPES = {
+    "L0": [jnp.float8_e4m3fn],
+    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
+}
-        def primitive_func(x, y, amax_list, scale_list):
-            fp8_meta_pkg = FP8MetaPackage(
-                amax_list[0],
-                scale_list[0],
-                amax_list[1],
-                scale_list[1],
-                amax_list[2],
-                scale_list[2],
-            )
-            primitive_out = type_safe_dot_general(x, y, fp8_meta_pkg)
-            return jnp.mean(primitive_out)
-
-        def ref_func(x, y):
-            return jnp.mean(jnp.dot(x, y))
+@pytest_parametrize_wrapper("n, hidden", LN_CASES)
+@pytest_parametrize_wrapper("inp_dtype", DTYPES)
+@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
+@pytest_parametrize_wrapper(
+    "zero_centered_gamma",
+    [
+        pytest.param(True, id="zero_centered"),
+        pytest.param(False, id="no_zero_centered"),
+    ],
+)
+@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
+class TestNorm:
+    """
+    Test transformer_engine.jax.layernorm APIs
+    """
-        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2, 3))
-        value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))
+    def _test_norm_grad(
+        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
+    ):
+        def compute_loss(x):
+            # Higher precision to compute the loss
+            x_ = x.astype(jnp.float32)
+            return jnp.mean(jnp.square(x_)).astype(x.dtype)
+
+        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
+            if norm_type == "rmsnorm":
+                ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
+            else:
+                ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
+            # if isinstance(ln_out, ScaledTensor):
+            #     ln_out = ln_out.dequantize()
+            return ln_out
-        ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
+        key = jax.random.PRNGKey(0)
+        subkeys = jax.random.split(key, 3)
+        x = jax.random.uniform(subkeys[0], (n, hidden), jnp.float32, -1, 1)
+        x = x.astype(inp_dtype)
+        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
+        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
+        gamma = jnp.asarray(gamma, inp_dtype)
+        if norm_type == "layernorm":
+            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
+            beta = jnp.asarray(beta, inp_dtype)
+        else:
+            beta = None
-        for _ in range(3):
-            primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
-                value_n_grad_primitive_func(a, b, amax_list, scale_list)
+        jitted_reference = jit(
+            value_and_grad(
+                lambda x, gamma, beta: compute_loss(
+                    reference_func(
+                        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
+                    )
+                ),
+                (0, 1, 2),
+            )
+        )
+        jitted_primitive = jit(
+            value_and_grad(
+                lambda x, gamma, beta: compute_loss(
+                    layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
+                ),
+                (0, 1, 2),
+            )
+        )
+
+        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
+            x, gamma, beta
+        )
+        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
+            x, gamma, beta
+        )
+
+        out_dtype = inp_dtype if quantizer is None else quantizer.q_dtype
+        assert_allclose(primitive_out, reference_out, dtype=out_dtype)
+        assert_allclose(primitive_dx, reference_dx, dtype=out_dtype)
+        assert_allclose(primitive_dgamma, reference_dgamma, dtype=out_dtype)
+        if beta is not None:
+            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)
-            assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
-            assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
-            assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
+    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
+        """
+        Test transformer_engine.jax.layernorm.layernorm
+        """
+        if norm_type == "rmsnorm" and zero_centered_gamma is True:
+            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
+        self._test_norm_grad(
+            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
+        )
```
```diff
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize(
-        "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
-    )
-    @pytest.mark.parametrize(
-        "activation_type",
-        [
-            ("gelu",),
-            ("gelu", "linear"),
-            ("silu",),
-            ("silu", "linear"),
-            ("relu",),
-            ("relu", "linear"),
-            ("quick_gelu",),
-            ("quick_gelu", "linear"),
-            ("squared_relu",),
-            ("squared_relu", "linear"),
-        ],
-    )
-    @pytest.mark.parametrize("use_bias", [True, False])
-    def test_grad_fused_layernorm_fp8_mlp(
-        self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
-    ):
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    # No Norm FWD E5M2 in TE backend
+    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
+    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+    def test_norm_grad_with_delayed_scaling_fp8(
+        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
+    ):
+        """
+        Test transformer_engine.jax.layernorm.layernorm
+        """
+        if norm_type == "rmsnorm" and zero_centered_gamma is True:
+            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
+        quantizer = QuantizerFactory.create(
+            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis
+        )
+        self._test_norm_grad(
+            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
+        )
+
+    def _test_norm_forward(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        scaling_mode,
+        q_axis,
+    ):
-        key = jax.random.PRNGKey(0)
-        subkeys = jax.random.split(key, 6)
+        key = jax.random.PRNGKey(0)
+        subkeys = jax.random.split(key, 3)
-        a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
-        k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) / jnp.sqrt(k)
-        k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
-        s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
-        if use_bias:
-            b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
-            b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
-        else:
-            b1 = None
-            b2 = None
+        x = jax.random.uniform(subkeys[0], (n, hidden), inp_dtype, -1, 1)
+        x = jnp.asarray(x, inp_dtype)
+        gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
+        gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
+        gamma = jnp.asarray(gamma, inp_dtype)
-        def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
-            # x is input tensor, matrix 2d
-            # y, z are weights, matrix 2d
-            # out = ((x * y) + w) * z + v
-            fp8_meta_pkg_1 = FP8MetaPackage(
-                amax_list_1[0],
-                scale_list_1[0],
-                amax_list_1[1],
-                scale_list_1[1],
-                amax_list_1[2],
-                scale_list_1[2],
-            )
+        quantizer, ref_quantizer = QuantizerFactory.create(
+            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis
+        )
+        if norm_type == "layernorm":
+            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
+            beta = jnp.asarray(beta, inp_dtype)
+            output, mu, rsigma = tex.layernorm_fwd(
+                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
+            )
-            fp8_meta_pkg_2 = FP8MetaPackage(
-                amax_list_2[0],
-                scale_list_2[0],
-                amax_list_2[1],
-                scale_list_2[1],
-                amax_list_2[2],
-                scale_list_2[2],
-            )
+            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
+                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
+            )
-            return jnp.mean(
-                fused_layernorm_fp8_mlp(
-                    x,
-                    ln_s,
-                    None,
-                    [y, z],
-                    [w, v],
-                    [fp8_meta_pkg_1, fp8_meta_pkg_2],
-                    "rmsnorm",
-                    activation_type=activation_type,
-                    use_bias=use_bias,
-                )
-            )
+        else:
+            output, rsigma = tex.rmsnorm_fwd(
+                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
+            )
```
def
layernorm_fp8_mlp_ref
(
x
:
jnp
.
ndarray
,
ln_scale
:
jnp
.
ndarray
,
kernel_1
:
jnp
.
ndarray
,
kernel_2
:
jnp
.
ndarray
,
bias_1
:
jnp
.
ndarray
,
bias_2
:
jnp
.
ndarray
,
amax_list_1
:
List
[
jnp
.
ndarray
],
amax_list_2
:
List
[
jnp
.
ndarray
],
scale_list_1
:
List
[
jnp
.
ndarray
],
scale_list_2
:
List
[
jnp
.
ndarray
],
)
->
jnp
.
ndarray
:
x
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
mean2
=
jnp
.
mean
(
jax
.
lax
.
square
(
x
),
axis
=-
1
,
keepdims
=
True
)
y
=
jnp
.
asarray
(
x
*
jax
.
lax
.
rsqrt
(
mean2
+
1e-6
),
jnp
.
bfloat16
)
ln_out
=
y
*
ln_scale
ln_out
=
jnp
.
asarray
(
ln_out
,
jnp
.
bfloat16
)
fp8_meta_pkg_1
=
FP8MetaPackage
(
amax_list_1
[
0
],
scale_list_1
[
0
],
amax_list_1
[
1
],
scale_list_1
[
1
],
amax_list_1
[
2
],
scale_list_1
[
2
],
ref_out
,
ref_rsigma
=
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
=
ref_quantizer
)
linear_1_out
=
type_safe_dot_general
(
ln_out
,
kernel_1
,
fp8_meta_pkg_1
,
((
1
,),
(
0
,)))
ref_mu
=
None
if
use_bias
:
bias_1_shape
=
(
1
,)
*
(
linear_1_out
.
ndim
-
bias_1
.
ndim
)
+
bias_1
.
shape
linear_1_out
+=
jnp
.
reshape
(
bias_1
,
bias_1_shape
)
assert_bitwise_scaled_tensors
(
output
,
ref_out
)
assert_allclose
(
rsigma
,
ref_rsigma
,
dtype
=
inp_dtype
)
if
norm_type
==
"layernorm"
:
assert_allclose
(
mu
,
ref_mu
,
dtype
=
inp_dtype
)
x
=
_jax_act_lu
(
linear_1_out
,
activation_type
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
def test_norm_forward_with_delayed_scaling_fp8(
    self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
):
    if norm_type == "rmsnorm" and zero_centered_gamma is True:
        pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
    self._test_norm_forward(
        n=n,
        hidden=hidden,
        norm_type=norm_type,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        inp_dtype=inp_dtype,
        out_dtype=out_dtype,
        scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
        q_axis=q_axis,
    )

fp8_meta_pkg_2 = FP8MetaPackage(
    amax_list_2[0], scale_list_2[0],
    amax_list_2[1], scale_list_2[1],
    amax_list_2[2], scale_list_2[2],
)
output = type_safe_dot_general(x, kernel_2, fp8_meta_pkg_2, ((1,), (0,)))

@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_norm_forward_with_block_scaling_fp8(
    self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
):
    self._test_norm_forward(
        n=n,
        hidden=hidden,
        norm_type=norm_type,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        inp_dtype=inp_dtype,
        out_dtype=out_dtype,
        scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
        q_axis=QuantizeAxis.ROWWISE_COLWISE,
    )

if use_bias:
    bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape
    output += jnp.reshape(bias_2, bias_2_shape)
return output

QUANTIZE_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}

def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
    return jnp.mean(
        layernorm_fp8_mlp_ref(
            x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
        )
    )

ALL_QUANTIZE_TEST_SHAPES = [
    (128, 128),
    (4, 256, 512),
]
value_n_grad_primitive_func = jit(
    value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
)
value_n_grad_ref_func = jit(
    value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
)
_, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
_, amax_list_2, scale_list_2 = TestFP8Dot._generate_fp8_meta()
ref_amax_list_1 = amax_list_1
ref_scale_list_1 = scale_list_1
ref_amax_list_2 = amax_list_2
ref_scale_list_2 = scale_list_2
primitive_amax_list_1 = amax_list_1
primitive_scale_list_1 = scale_list_1
primitive_amax_list_2 = amax_list_2
primitive_scale_list_2 = scale_list_2
primitive_amax_list_1, primitive_scale_list_1, primitive_amax_list_2, primitive_scale_list_2
# Convert str to index as str is not a valid type for JAX JIT

for _ in range(3):
    ref_out, (
        ref_a_grad,
        ref_s_grad,
        ref_k1_grad,
        ref_k2_grad,
        ref_b1_grad,
        ref_b2_grad,
        ref_amax_list_1,
        ref_amax_list_2,
        ref_scale_list_1,
        ref_scale_list_2,
    ) = value_n_grad_ref_func(
        a, s, k1, k2, b1, b2,
        ref_amax_list_1, ref_amax_list_2, ref_scale_list_1, ref_scale_list_2,
    )

QUANTIZE_TEST_SHAPES = {
    "L0": [
        (256, 128),
        (64, 16, 2, 256),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES,
}

for _ in range(3):
    primitive_out, (
        primitive_a_grad,
        primitive_s_grad,
        primitive_k1_grad,
        primitive_k2_grad,
        primitive_b1_grad,
        primitive_b2_grad,
        primitive_amax_list_1,
        primitive_amax_list_2,
        primitive_scale_list_1,
        primitive_scale_list_2,
    ) = value_n_grad_primitive_func(
        a, s, k1, k2, b1, b2,
        primitive_amax_list_1, primitive_amax_list_2,
        primitive_scale_list_1, primitive_scale_list_2,
    )

QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(
    jnp.asarray(primitive_a_grad, np.float32),
    jnp.asarray(ref_a_grad, np.float32),
    dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
    jnp.asarray(primitive_k1_grad, np.float32),
    jnp.asarray(ref_k1_grad, np.float32),
    dtype=FP8Helper.BWD_DTYPE,
)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization-related tests that always run on a wider set of types and shapes.
    """

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
        key = jax.random.PRNGKey(0)
        # Quantizer is created once, as some quantization approaches use state
        # from previous iterations (e.g. delayed scaling).
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_axis=q_axis,
        )
        n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)
            scaled_tensor = quantizer.quantize(x)
            assert_dequantized_scaled_tensor(scaled_tensor, x)
    def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(input_shape):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis
        )
        jax_output = _jax_quantize(input, quantizer=jax_quantizer)
        te_output = tex.quantize(input, quantizer=te_quantizer)
        assert_bitwise_scaled_tensors(jax_output, te_output)

assert_allclose(
    jnp.asarray(primitive_s_grad, np.float32),
    jnp.asarray(ref_s_grad, np.float32),
    dtype=FP8Helper.BWD_DTYPE,
)
assert_allclose(
    jnp.asarray(primitive_k2_grad, np.float32),
    jnp.asarray(ref_k2_grad, np.float32),
    dtype=FP8Helper.BWD_DTYPE,
)
if use_bias:
    assert_allclose(
        jnp.asarray(primitive_b2_grad, np.float32),
        jnp.asarray(ref_b2_grad, np.float32),
        dtype=FP8Helper.BWD_DTYPE,
    )
    assert_allclose(
        jnp.asarray(primitive_b1_grad, np.float32),
        jnp.asarray(ref_b1_grad, np.float32),
        dtype=FP8Helper.BWD_DTYPE,
    )

@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
    def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis):
        transpose_axis = -1
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(input_shape):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        jax_quantizer, te_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
        )
        te_output, te_dbias = jit(
            lambda input: tex.quantize_dbias(input, quantizer=te_quantizer)
        )(input)
        jax_output, jax_dbias = jit(
            lambda input: _jax_quantize_dbias(input, quantizer=jax_quantizer)
        )(input)
        assert_bitwise_scaled_tensors(jax_output, te_output)
        assert_allclose(jax_dbias, te_dbias)

@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out

def _test_quantize_dact_dbias(
    self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis
):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 2)
    x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
    x = jnp.repeat(x, len(activation_type), axis=-1)
    dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)
    jax_quantizer, te_quantizer = QuantizerFactory.create(
        n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
    )
    is_casted_output = te_quantizer is not None
    te_output, te_dbias = jit(
        lambda dz, x: tex.quantize_dact_dbias(
            dz, x, activation_type=activation_type, is_dbias=is_dbias, quantizer=te_quantizer,
        )
    )(dz, x)
    jax_output, jax_dbias = jit(
        lambda dz, x: _jax_quantize_dact_dbias(
            dz, x, activation_type=activation_type, is_dbias=is_dbias, quantizer=jax_quantizer,
        )
    )(dz, x)
    if is_casted_output:
        assert_bitwise_scaled_tensors(jax_output, te_output)
    else:
        assert_allclose(jax_output, te_output)
    if is_dbias:
        assert_allclose(jax_dbias, te_dbias)

class TestActivationLu:
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
def test_quantize_dact_dbias_no_quantization(
    self, in_dtype, input_shape, activation_type, is_dbias,
):
    self._test_quantize_dact_dbias(
        in_dtype=in_dtype,
        input_shape=input_shape,
        out_dtype=in_dtype,
        scaling_mode=ScalingMode.NVTE_NO_SCALING,
        activation_type=activation_type,
        is_dbias=is_dbias,
        q_axis=QuantizeAxis.ROWWISE,
    )

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
def test_quantize_dact_dbias_delayed_scaling(
    self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
):
    self._test_quantize_dact_dbias(
        in_dtype=in_dtype,
        input_shape=input_shape,
        out_dtype=out_dtype,
        scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
        activation_type=activation_type,
        is_dbias=is_dbias,
        q_axis=q_axis,
    )

def ref_func(self, x, activation_type):
    def ref_act_lu(inputs):
        x = _jax_act_lu(inputs, activation_type)
        return jnp.mean(x)
@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper(
    "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@pytest_parametrize_wrapper("is_dbias", [True, False])
@pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
def test_quantize_dact_dbias_mxfp8_scaling(
    self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
):
    if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
        # TODO(Jeremy): Remove this if pulling in a newer TE branch supports non-full-tile shapes.
        # If it doesn't, move this check into the quantize_dact_dbias function and revert to the JAX
        # implementation in the unsupported cases.
        pytest.skip(f"Input shape {input_shape} is not supported by the dact MXFP8 kernel in TE currently")
    self._test_quantize_dact_dbias(
        in_dtype=in_dtype,
        input_shape=input_shape,
        out_dtype=out_dtype,
        scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
        activation_type=activation_type,
        is_dbias=is_dbias,
        q_axis=q_axis,
    )

    ref_act_func = jit(value_and_grad(ref_act_lu, (0,)))
    return ref_act_func(x)

def primitive_func(self, inputs):
    return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))

@pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize(
    "activation_type",
    [
        ("gelu",),
        ("gelu", "linear"),
        ("silu",),
        ("silu", "linear"),
        ("relu",),
        ("relu", "linear"),
        ("quick_gelu",),
        ("quick_gelu", "linear"),
        ("squared_relu",),
        ("squared_relu", "linear"),
    ],
)
def test_activation_lu(self, random_inputs, activation_type):
    x = random_inputs
    x = jnp.repeat(x, len(activation_type), axis=-2)
    self.activation_type = activation_type
class TestDense:
    def _ref_gemm_with_jnp_dot(self, a, b, layout):
        if layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
        if layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)))

def _generate_gemm_input(self, m, n, k, layout):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 2)
    x = jax.random.uniform(
        subkeys[0],
        (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
        dtype=jnp.bfloat16,
    ) / jnp.sqrt(k)
    w = jax.random.uniform(
        subkeys[1],
        (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
        dtype=jnp.bfloat16,
    ) / jnp.sqrt(n)
    lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
    rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
    contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
    return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
def test_gemm_bf16(self, m, n, k, layout):
    x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
    primitive_out = tex.gemm(x, w, contracting_dims)
    ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
    assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

prim_out, (prim_grad,) = value_n_grad_primitive_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout):
    x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
    quantizer_set = QuantizerFactory.create_set(
        scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
    )
    primitive_out = tex.gemm(
        x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
    )
    ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
    assert_allclose(primitive_out, ref_out, dtype=q_dtype)

class TestActivationLuFP8(TestActivationLu):
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
def test_dense_grad_bf16(self, m, n, k):
    layout = "NN"
    x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)

    def primitive_func(x, w, contracting_dims):
        primitive_out = dense(x, w, contracting_dims=contracting_dims)
        return jnp.mean(primitive_out)

    def ref_func(x, w, layout):
        return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout))

    value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
    value_n_grad_ref_func = value_and_grad(ref_func, (0, 1))

def prim_func(self, x):
    amax = self.amax
    scale = self.scale
    scale_inv = self.scale_inv
    activation_type = self.activation_type

    @jax.custom_vjp
    def _prim_func(x, _x_t, _dbias, _amax):
        output = _prim_func_fwd(x, _x_t, _dbias, _amax)
        return output

    def _prim_func_fwd(x, _x_t, _dbias, _amax):
        activation_lu_out, _ = tex.act_lu_fp8(
            x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
        )
        activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
        ctx = x
        return activation_lu_out, ctx

    def _prim_func_bwd(ctx, g):
        x = ctx
        if len(self.activation_type) > 1:
            # gated, no bias
            dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
                g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
            )
            dbias = jnp.empty(x.shape[-1], x.dtype)
        else:
            # not gated, with bias
            dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
                tex.dact_lu_dbias_cast_transpose(
                    g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1,
                    self.activation_type,
                )
            )
        dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
        dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
        ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
        return ctx

    _prim_func.defvjp(_prim_func_fwd, _prim_func_bwd)
    dx_trans_no_use = jnp.empty(
        [x.shape[i] for i in self.transpose_axes], dtype=x.dtype
    )
    dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
    amax_no_use = jnp.zeros(1, jnp.float32)
    value_n_grad_primitive_func = value_and_grad(
        lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
    )
    return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)

primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
    x, w, contracting_dims
)
ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("shape", [(32, 1, 64), (16, 64, 1, 256)])
@pytest.mark.parametrize(
    "activation_type",
    [
        ("gelu",),
        ("gelu", "linear"),
        ("silu",),
        ("silu", "linear"),
        ("relu",),
        ("relu", "linear"),
        ("quick_gelu",),
        ("quick_gelu", "linear"),
        ("squared_relu",),
        ("squared_relu", "linear"),
    ],
)
def test_activation_lu(self, random_inputs, activation_type):
    self.amax = jnp.zeros(1, jnp.float32)
    self.scale = jnp.ones(1, jnp.float32)
    self.scale_inv = jnp.ones(1, jnp.float32)
    self.activation_type = activation_type
@pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
    layout = "NN"
    x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
    key = jax.random.PRNGKey(1)
    bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)

    def primitive_func(x, w, bias, contracting_dims, quantizer_set):
        primitive_out = dense(
            x, w, bias, contracting_dims=contracting_dims, quantizer_set=quantizer_set
        )
        return jnp.mean(primitive_out)

x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=-2)
axes = jnp.arange(x.ndim)
self.transpose_axes = tuple([*axes[-2:]] + [*axes[:-2]])
print(self.transpose_axes)

def ref_func(x, w, bias, layout):
    return jnp.mean(
        self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0)
    )

prim_out, (prim_grad, prim_grad_trans, dbias, amax) = self.prim_func(x)
ref_out, (ref_grad,) = self.ref_func(x, activation_type)
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
if "linear" not in activation_type:
    assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(
    prim_grad_trans,
    jnp.transpose(ref_grad, self.transpose_axes),
    dtype=FP8Helper.BWD_DTYPE,
)

quantizer_set = QuantizerFactory.create_set(
    scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
)
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
    primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
        value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
    )
    ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
        x, w, bias, layout
    )
    assert_allclose(primitive_out, ref_out, dtype=q_dtype)
    assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
    assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
    assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
class TestNorm:
    """
    Test transformer_engine.jax.layernorm APIs
    """

    @staticmethod
    def _generate_fp8_meta():
        fp8_dtype_list = [FP8Helper.FWD_DTYPE, FP8Helper.FWD_DTYPE, FP8Helper.BWD_DTYPE]
        amax_list = [
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
        ]
        scale_list = [
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
            jnp.ones((1,), jnp.float32),
        ]
        return fp8_dtype_list, amax_list, scale_list

    def reference_layernorm(self, x, scale, bias, zero_centered_gamma, eps):

@pytest.fixture(name="random_inputs")
def random_inputs_fixture(shape):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 4)
    out = jax.random.uniform(subkeys[0], shape, jnp.bfloat16, 5, 8)
    return out

def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
    if norm_type == "rmsnorm":
        ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
    else:
        ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
    if isinstance(ln_out, ScaledTensor):
        ln_out = ln_out.dequantize()
    return ln_out
class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("m,n,k", [(512, 128, 128)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
        """
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")
        # zero_centered_gamma is already tested in TestNorm
        zero_centered_gamma = False
        eps = 1e-6
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
        # NN in FWD
        x = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16) / jnp.sqrt(k)
        w = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode,
            fwd_dtype=q_dtype,
            bwd_dtype=q_dtype,
            is_2x2x=True,
        )
        if norm_type == "layernorm":
            beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)

"""
JAX native layernorm implementations
- bias is not None: layernorm
- bias is None: rmsnorm
"""
x_ = jnp.asarray(x, jnp.float32)
if bias is None:
    mean = 0.0
else:
    mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
    scale += 1.0
if bias is None:
    bias = 0.0
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
@pytest.mark.parametrize("n, hidden", LN_CASES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_layernorm_forward_backward(
    self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
):
    beta = None

def prim_func(x, w, gamma, beta):
    # bias = None as quantize_dbias is already tested in test_dense_grad_fp8
    prim_out = layernorm_dense(
        x, w, gamma, beta, None, norm_type, zero_centered_gamma, eps,
        quantizer_set=quantizer_set,
    )
    return jnp.mean(prim_out)

def ref_func(x, w, gamma, beta):
    x = _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None)
    return jnp.mean(jnp.dot(x, w))

value_n_grad_prim_func = value_and_grad(prim_func, (0, 1, 2, 3))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_x_grad, ref_w_grad, ref_gamma_grad, ref_beta_grad) = value_n_grad_ref_func(
    x, w, gamma, beta
)
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
    prim_out, (
        prim_x_grad,
        prim_w_grad,
        prim_gamma_grad,
        prim_beta_grad,
    ) = value_n_grad_prim_func(x, w, gamma, beta)

assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
if beta is not None:
    assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", [(512, 128, 256)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_grad(
    self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
):
    """
    Test layernorm_mlp VJP Rule
    """

"""
Test transformer_engine.jax.layernorm.layernorm
"""
expect_assert = False
if ln_type == "rmsnorm" and zero_centered_gamma:
    # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
    expect_assert = True
with (
    pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
    if expect_assert
    else nullcontext()
):
    key = jax.random.PRNGKey(0)
    subkeys = jax.random.split(key, 3)
    x = jax.random.uniform(subkeys[0], (n, hidden), dtype, -1, 1)
    gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
    gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
    gamma = jnp.asarray(gamma, dtype)
    if ln_type == "layernorm":
        beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
        beta = jnp.asarray(beta, dtype)
    else:
        beta = None

    def compute_loss(x):
        # Higher precision to compute the loss
        x_ = x.astype(jnp.float32)
        return jnp.mean(jnp.square(x_)).astype(x.dtype)

    jitted_primitive = jit(
        value_and_grad(
            lambda x, gamma, beta: compute_loss(
                layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
            ),
            (0, 1, 2),
        )
    )
# No Norm FWD E5M2 in TE backend
if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
    pytest.skip("E5M2 is not supported in normalization with TE Backend!")
# zero_centered_gamma is already tested in TestNorm
zero_centered_gamma = False
eps = 1e-6
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
kernel_1 = jax.random.normal(
    subkeys[1], (k, len(activation_type) * n), jnp.bfloat16
) / jnp.sqrt(k)
kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
beta = None  # was tested in TestNorm
if use_bias:
    bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16)
    bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
else:
    bias_1 = None
    bias_2 = None
quantizer_sets = QuantizerFactory.create_set(
    n_quantizer_sets=2,
    scaling_mode=scaling_mode,
    fwd_dtype=q_dtype,
    bwd_dtype=q_dtype,
    is_2x2x=True,
)
if norm_type == "layernorm":
    beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
    beta = None

def prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
    return jnp.mean(
        layernorm_mlp(
            x,
            gamma,
            beta,
            [kernel_1, kernel_2],
            [bias_1, bias_2],
            norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=eps,
            activation_type=activation_type,
            quantizer_sets=quantizer_sets,
        )
    )

jitted_reference = jit(
    value_and_grad(
        lambda x, gamma, beta: compute_loss(
            self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
        ),
        (0, 1, 2),
    )
)
def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
    ln_out = _ref_jax_norm_impl(
        x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
    )
    # TODO: replace gemm with jnp.dot
    linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
    if use_bias:
        bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
        linear_1_out += jnp.reshape(bias_1, bias_1_shape)
    x = _jax_act_lu(linear_1_out, activation_type)
    linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
    if use_bias:
        bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
        linear_2_out += jnp.reshape(bias_2, bias_2_shape)
    return linear_2_out

def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2):
    return jnp.mean(_ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2))

value_n_grad_prim_func = value_and_grad(prim_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, range(6))
n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
for _ in range(n_iterations):
    prim_out, (
        prim_x_grad,
        prim_gamma_grad,
        prim_kernel_1_grad,
        prim_kernel_2_grad,
        prim_bias_1_grad,
        prim_bias_2_grad,
    ) = value_n_grad_prim_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

ref_out, (
    ref_x_grad,
    ref_gamma_grad,
    ref_kernel_1_grad,
    ref_kernel_2_grad,
    ref_bias_1_grad,
    ref_bias_2_grad,
) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

assert_allclose(prim_out, ref_out, dtype=q_dtype)
assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
if use_bias:
    assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
if use_bias:
    assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
# This function is modified from transformer_engine/jax/cpp_extensions/gemm.py::_jax_gemm()
def _quantize_gemm_pair(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer):
    ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
    lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
    rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
    lhs_q = lhs_quantizer.quantize(
        lhs,
        is_rowwise=lhs_is_rowwise,
        is_colwise=not lhs_is_rowwise,
    )
    rhs_q = rhs_quantizer.quantize(
        rhs,
        is_rowwise=rhs_is_rowwise,
        is_colwise=not rhs_is_rowwise,
    )
    return lhs_q, rhs_q

primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
    x, gamma, beta
)

# E5M2 * E5M2 is not supported
fwd_bwd_dtypes = [
    [jnp.float8_e4m3fn, jnp.float8_e4m3fn],
    [jnp.float8_e4m3fn, jnp.float8_e5m2],
    [jnp.float8_e5m2, jnp.float8_e4m3fn],
]
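The rowwise/colwise choice in _quantize_gemm_pair keeps the quantized layout aligned with the contraction: when a tensor contracts over its last axis, row-wise scaling matches the GEMM's reduction direction, otherwise the column-wise copy is used. The selection logic in isolation looks like this (an illustrative helper name, not part of the TE API):

def pick_quantize_direction(ndim, contract_dim):
    """Return (is_rowwise, is_colwise) for a tensor contracted over contract_dim."""
    is_rowwise = contract_dim == ndim - 1
    return is_rowwise, not is_rowwise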
@pytest_parametrize_wrapper(
    "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
class TestGroupedDense:
    def _ref_grouped_gemm_with_jnp_dot(self, lhs_list, rhs_list, contracting_dims_list):
        ref_out_list = []
        for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
            dim_nums = (contracting_dims, ((), ()))
            ref_out_list.append(jax.lax.dot_general(lhs, rhs, dim_nums))
        return ref_out_list
    def _generate_grouped_gemm_input(self, dtype, shape_list, layout_list):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, len(shape_list) * 2)
        lhs_list, rhs_list, contracting_dims_list = [], [], []
        for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
                (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
                (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
                dtype=dtype,
            )
            lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
            rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
            lhs_list.append(lhs)
            rhs_list.append(rhs)
            contracting_dims_list.append(contracting_dims)
        return lhs_list, rhs_list, contracting_dims_list

reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
    x, gamma, beta
)
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
assert_allclose(primitive_dgamma, reference_dgamma, dtype=dtype)
if beta is not None:
    assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
    @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
    @pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
    def test_grouped_gemm_fp16(self, dtype, shape_list, layout_list):
        lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
            dtype, shape_list, layout_list
        )
        ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
        primitive_out = tex.grouped_gemm(lhs_list, rhs_list, contracting_dims_list)
        for i in range(len(shape_list)):
            assert_allclose(primitive_out[i], ref_out[i], dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("m,n,k", GEMM_CASES)
@pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
@pytest.mark.parametrize("zero_centered_gamma", [True, False])
@pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
    """
    Test transformer_engine.jax.layernorm.layernorm_fp8_dot
    """
    expect_assert = False
    if ln_type == "rmsnorm" and zero_centered_gamma:
        # zero_centered_gamma is not supported for rmsnorm, expect an assertion.
        expect_assert = True
    with (
        pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
        if expect_assert
        else nullcontext()
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 4)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout_list", [["NN", "TN", "NT", "TT"]])
def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list, layout_list):
    fwd_dtype, bwd_dtype = fwd_bwd_dtype
    quantizer_set = QuantizerFactory.create_set(
        scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=False
    )
    out_dtype = jnp.bfloat16
    lhs_list, rhs_list, contracting_dims_list = self._generate_grouped_gemm_input(
        out_dtype, shape_list, layout_list
    )
    q_lhs_list = []
    q_rhs_list = []
    for lhs, rhs, contracting_dims in zip(lhs_list, rhs_list, contracting_dims_list):
        # quantizer_set.x and quantizer_set.kernel have the same q_dtype; we want to
        # test the case where lhs and rhs have different q_dtypes
        q_lhs, q_rhs = _quantize_gemm_pair(
            lhs, rhs, contracting_dims, quantizer_set.x, quantizer_set.dgrad
        )
        q_lhs_list.append(q_lhs)
        q_rhs_list.append(q_rhs)

a = jax.random.normal(subkeys[0], (m, k)).astype(jnp.bfloat16)
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
if ln_type == "layernorm":
    beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
    beta = None
_, amax_list_1, scale_list_1 = TestNorm._generate_fp8_meta()

def primitive_func(x, y, gamma, beta, amax_list_1, scale_list_1):
    fp8_meta_pkg = FP8MetaPackage(
        amax_list_1[0], scale_list_1[0],
        amax_list_1[1], scale_list_1[1],
        amax_list_1[2], scale_list_1[2],
    )
    primitive_out = layernorm_fp8_dot(
        x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
    )
    return jnp.mean(primitive_out)

ref_out = self._ref_grouped_gemm_with_jnp_dot(lhs_list, rhs_list, contracting_dims_list)
primitive_out = tex.grouped_gemm(q_lhs_list, q_rhs_list, contracting_dims_list)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
    allclose_dtype = jnp.float8_e5m2
for i in range(len(shape_list)):
    assert_allclose(primitive_out[i], ref_out[i], dtype=allclose_dtype)
@pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
def test_grouped_dense_grad_fp16(self, dtype, shape_list):
    group_size = len(shape_list)
    layout_list = ["NN" for _ in range(group_size)]
    x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
        dtype, shape_list, layout_list
    )
    bias_list = []
    key = jax.random.PRNGKey(1)
    for shape in shape_list:
        n = shape[1]
        bias = jax.random.uniform(key, n, dtype=dtype)
        bias_list.append(bias)

    def ref_func(x_list, kernel_list, bias_list, contracting_dims_list):
        out_list = []
        for i in range(len(x_list)):
            out_list.append(
                dense(
                    x_list[i],
                    kernel_list[i],
                    bias_list[i],
                    contracting_dims=contracting_dims_list[i],
                )
            )
        # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
        # and prevent them from being clamped to zero
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list))

    def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list):
        out_list = grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list)
        out_sum_list = [jnp.sum(out) for out in out_list]
        return jnp.sum(jnp.asarray(out_sum_list))

    value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
    value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
    ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
        x_list, kernel_list, bias_list, contracting_dims_list
    )
    primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
        value_n_grad_primitive_func(x_list, kernel_list, bias_list, contracting_dims_list)
    )
    assert_allclose(primitive_out_mean, ref_out_mean, dtype=dtype)
    for i in range(group_size):
        assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=dtype)
        assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=dtype)
        assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=dtype)

def ref_func(x, y, gamma, beta, zero_centered_gamma):
    x = self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
    return jnp.mean(jnp.dot(x, y))

value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
    value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, shape_list):
    group_size = len(shape_list)
    layout_list = ["NN" for _ in range(group_size)]
    fwd_dtype, bwd_dtype = fwd_bwd_dtype
    if fwd_dtype == jnp.float8_e5m2:
        pytest.skip("We never use E5M2 for fwd_dtype in training")
    # Question: should we use different quantizers for different groups?
    ref_quantizer_set_list = []
    quantizer_set_list = []
    for _ in range(group_size):
        ref_quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
        )
        ref_quantizer_set_list.append(ref_quantizer_set)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=fwd_dtype, bwd_dtype=bwd_dtype, is_2x2x=True
        )
        quantizer_set_list.append(quantizer_set)

for _ in range(3):
    primitive_out, (
        primitive_a_grad,
        primitive_b_grad,
        primitive_gamma_grad,
        primitive_beta_grad,
        amax_list_1,
        scale_list_1,
    ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)

assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_gamma_grad, ref_gamma_grad, dtype=FP8Helper.BWD_DTYPE)
if beta is not None:
    assert_allclose(primitive_beta_grad, ref_beta_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "input_shape, transpose_axis",
    [
        pytest.param((16, 16), 1, id="(16, 16)-1"),
        pytest.param((256, 128), 1, id="(256, 128)-1"),
        pytest.param((128, 512), 1, id="(128, 512)-1"),
        pytest.param((64, 16, 4, 256), 1, id="(64, 16, 4, 256)-1"),
        pytest.param((64, 16, 4, 256), 2, id="(64, 16, 4, 256)-2"),
        pytest.param((64, 16, 4, 256), 3, id="(64, 16, 4, 256)-3"),
    ],
)
class TestTranspose:
    def test_transpose(self, in_dtype, input_shape, transpose_axis):
        key = jax.random.PRNGKey(0)
        input_tensor = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.transpose(input_tensor, static_axis_boundary, transpose_axis)
        assert_allclose(jax_output, noffi_output)
        assert_allclose(noffi_output, ffi_output)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_cast_transpose(
            input, scale, amax, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "1"
        ffi_output = tex.cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )
        assert_tree_like_allclose(jax_output, ffi_output)
        assert_tree_like_allclose(noffi_output, ffi_output)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
            pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
        ],
    )
    def test_dbias_cast_transpose(self, in_dtype, input_shape, transpose_axis, out_dtype):
        amax = jnp.zeros(1, jnp.float32)
        scale = jnp.ones(1, jnp.float32)
        scale_inv = jnp.ones(1, jnp.float32)
        key = jax.random.PRNGKey(0)
        input = jax.random.uniform(key, input_shape, in_dtype)
        static_axis_boundary = -1
        jax_output = _jax_dbias_cast_transpose(
            input, amax, scale, out_dtype, static_axis_boundary, transpose_axis
        )
        os.environ["NVTE_JAX_WITH_FFI"] = "0"
        noffi_output = tex.dbias_cast_transpose(
            input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
        )

out_dtype = jnp.bfloat16
x_list, kernel_list, contracting_dims_list = self._generate_grouped_gemm_input(
    out_dtype, shape_list, layout_list
)
bias_list = []
key = jax.random.PRNGKey(1)
for shape in shape_list:
    n = shape[1]
    bias = jax.random.uniform(key, n, dtype=out_dtype)
    bias_list.append(bias)
def ref_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
    out_list = []
    for i in range(len(x_list)):
        out_list.append(
            dense(
                x_list[i],
                kernel_list[i],
                bias_list[i],
                contracting_dims=contracting_dims_list[i],
                quantizer_set=quantizer_set_list[i],
            )
        )
    # Note: we use jnp.sum instead of jnp.mean to make the gradients larger
    # and prevent them from being clamped to zero
    out_sum_list = [jnp.sum(out) for out in out_list]
    return jnp.sum(jnp.asarray(out_sum_list))

def primitive_func(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
    out_list = grouped_dense(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
    out_sum_list = [jnp.sum(out) for out in out_list]
    return jnp.sum(jnp.asarray(out_sum_list))

value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))
value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
ref_out_mean, (ref_dgrad_list, ref_wgrad_list, ref_dbias_list) = value_n_grad_ref_func(
    x_list, kernel_list, bias_list, contracting_dims_list, ref_quantizer_set_list
)
primitive_out_mean, (primitive_dgrad_list, primitive_wgrad_list, primitive_dbias_list) = (
    value_n_grad_primitive_func(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
)

os.environ["NVTE_JAX_WITH_FFI"] = "1"
ffi_output = tex.dbias_cast_transpose(
    input, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis
)
assert_tree_like_allclose(jax_output, ffi_output)
assert_tree_like_allclose(noffi_output, ffi_output)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize(
    "input_shape",
    [
        pytest.param((256, 128), id="(256, 128)"),
        pytest.param((128, 512, 8), id="(128, 512, 8)"),
    ],
)
@pytest.mark.parametrize(
    "in_dtype",
    [
        pytest.param(jnp.float32, id="input_float32"),
        pytest.param(jnp.float16, id="input_float16"),
        pytest.param(jnp.bfloat16, id="input_bfloat16"),
    ],
)
@pytest.mark.parametrize(
    "out_dtype",
    [
        pytest.param(jnp.float8_e4m3fn, id="output_float8_e4m3fn"),
        pytest.param(jnp.float8_e5m2, id="output_float8_e5m2"),
    ],
)
def test_quantize(input_shape, in_dtype, out_dtype):
    amax = jnp.zeros(1, jnp.float32)
    scale = jnp.ones(1, jnp.float32)
    scale_inv = jnp.ones(1, jnp.float32)
    key = jax.random.PRNGKey(0)
    input = jax.random.uniform(key, input_shape, in_dtype)
    jax_output = _jax_cast_fp8(input, scale, amax, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "0"
    noffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    os.environ["NVTE_JAX_WITH_FFI"] = "1"
    ffi_output = tex.cast_fp8(input, amax, scale, scale_inv, out_dtype)
    assert_tree_like_allclose(jax_output, ffi_output)
    assert_tree_like_allclose(noffi_output, ffi_output)
allclose_dtype = jnp.float8_e4m3fn
if fwd_dtype == jnp.float8_e5m2 or bwd_dtype == jnp.float8_e5m2:
    allclose_dtype = jnp.float8_e5m2
assert_allclose(primitive_out_mean, ref_out_mean, dtype=allclose_dtype)
for i in range(group_size):
    assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
    assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
    assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
tests/jax/test_distributed_fused_attn.py  View file @ a207db1d
...
@@ -6,7 +6,6 @@ import os
import pytest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from distributed_test_base import (
    generate_configs,
...
@@ -104,7 +103,7 @@ class TestDistributedSelfAttn:
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")
            pytest.skip("No FusedAttn backend found")
        col_ref = self.generate_collectives_count_ref(
            mesh_shape,
...
@@ -176,7 +175,7 @@ class TestDistributedCrossAttn:
            hidden,
            None,  # no window
        ):
            pytest.skip(f"No FusedAttn backend found")
            pytest.skip("No FusedAttn backend found")
        col_ref = self.generate_collectives_count_ref()
        runner = FusedAttnRunner(
...
@@ -256,7 +255,6 @@ class TestDistributedContextParallelSelfAttn:
        dropout_prob = 0.0
        is_training = True
        dp_size, cp_size, tp_size = mesh_shape
        qkv_format = qkv_layout.get_qkv_format()
        batch, seqlen, num_head, hidden = data_shape
...
@@ -382,7 +380,7 @@ class TestDistributedContextParallelSelfAttn:
        if qkv_layout.is_thd() and not load_balanced:
            pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
        return self.impl_test_context_parallel_attn(
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
...
@@ -396,6 +394,7 @@ class TestDistributedContextParallelSelfAttn:
            CPStrategy.RING,
        )
        del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
        return

class TestReorderCausalLoadBalancing:
...
tests/jax/test_distributed_layernorm.py  View file @ a207db1d
...
@@ -13,11 +13,30 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec
from distributed_test_base import generate_configs, generate_collectives_count
from distributed_test_base import compare_ops
from utils import pytest_parametrize_wrapper
from transformer_engine.jax import fp8_autocast
from transformer_engine.common import recipe
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available

DTYPES = [jnp.bfloat16, jnp.float32]
NORM_INPUT_SHAPES = {
    "L0": [[64, 64]],
    "L2": [[64, 64]],
}
is_fp8_supported, reason = is_fp8_available()
is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

class TestDistributedLayernorm:
...
@@ -41,25 +60,32 @@ class TestDistributedLayernorm:
        return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)

    def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
    def generate_collectives_count_ref(
        self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
    ):
        jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
        is_dp_enabled = mesh_resource.dp_resource is not None
        assert ln_type in ["layernorm", "rmsnorm"]
        all_reduce_loss_bytes = 4  # 1 * FP32
        # for loss, dgamma and dbeta
        weight_count = 2 if ln_type == "layernorm" else 1
        # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
        weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
        allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
        other_bytes = 0
        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
            other_bytes = 384  # required for small scale shapes that require padding
        return generate_collectives_count(
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
            allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
        )
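For intuition, plugging illustrative numbers (not taken from a test run) into the formula above: for layernorm on a bf16 input whose last dimension is 1024, with data parallelism enabled, the expected allreduce traffic per step is

# Illustrative byte count for a hidden size of 1024 under bf16 and dp
all_reduce_loss_bytes = 4            # one FP32 loss scalar
weight_count = 2                     # dgamma and dbeta
hidden, itemsize = 1024, 2           # last dim, bf16 bytes
allreduce_total_bytes = all_reduce_loss_bytes + weight_count * hidden * itemsize
assert allreduce_total_bytes == 4100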
    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
    @pytest.mark.parametrize("dtype", DTYPES)
    @pytest.mark.parametrize("zero_centered_gamma", [False, True])
    @pytest.mark.parametrize("shard_weights", [False, True])
    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
    def test_layernorm(
        self,
        device_count,
...
@@ -70,12 +96,19 @@ class TestDistributedLayernorm:
        dtype,
        zero_centered_gamma,
        shard_weights,
        fp8_recipe,
    ):
        epsilon = 1e-6
        ln_type = "layernorm"
        q_dtype = jnp.float8_e4m3fn

        def target_func(x, gamma, beta):
            return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
            quantizer = QuantizerFactory.create_set().x
            return jnp.mean(
                layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon, quantizer=quantizer)
            )

        def ref_func(x, gamma, beta):
            x_ = jnp.asarray(x, jnp.float32)
...
@@ -92,11 +125,11 @@ class TestDistributedLayernorm:
            data_shape, mesh_resource, dtype, shard_weights
        )
        collective_count_ref = self.generate_collectives_count_ref(
            mesh_resource, ln_type, data_shape, dtype
            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
        )
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(mesh_resource=mesh_resource):
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
...
@@ -109,8 +142,8 @@ class TestDistributedLayernorm:
                 [x_, gamma_, beta_],
                 collective_count_ref,
                 grad_args=(0, 1, 2),
-                metric_fwd_dtype=dtype,
-                metric_bwd_dtype=dtype,
+                metric_fwd_dtype=q_dtype,
+                metric_bwd_dtype=q_dtype,
                 in_shardings=(x_pspec, g_pspec, b_pspec),
                 out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
             )
...
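A minimal sketch of the mesh-plus-sharding harness these hunks drive, assuming at least two visible devices; in the real test the body also sits inside fp8_autocast(enabled=True, fp8_recipe=..., mesh_resource=...) as shown above.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh and place an array with an explicit sharding,
# exactly the device_put pattern the test uses for x, gamma, and beta.
devices = np.asarray(jax.devices()[:2]).reshape(2)
mesh = Mesh(devices, ("dp",))
x = jax.random.normal(jax.random.PRNGKey(0), (32, 1024))
with mesh:
    x_sharded = jax.device_put(x, NamedSharding(mesh, PartitionSpec("dp", None)))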
@@ -131,17 +164,28 @@ class TestDistributedLayernorm:
     )

     @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("data_shape", NORM_INPUT_SHAPES)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("shard_weights", [False, True])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     def test_rmsnorm(
-        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
+        self,
+        device_count,
+        mesh_shape,
+        mesh_axes,
+        mesh_resource,
+        data_shape,
+        dtype,
+        shard_weights,
+        fp8_recipe,
     ):
         epsilon = 1e-6
         ln_type = "rmsnorm"
+        q_dtype = jnp.float8_e4m3fn

         def target_func(x, gamma):
-            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
+            quantizer = QuantizerFactory.create_set().x
+            return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon, quantizer=quantizer))

         def ref_func(x, gamma):
             x = jnp.asarray(x, jnp.float32)
...
@@ -154,11 +198,11 @@ class TestDistributedLayernorm:
             data_shape, mesh_resource, dtype, shard_weights
         )
         collective_count_ref = self.generate_collectives_count_ref(
-            mesh_resource, ln_type, data_shape, dtype
+            mesh_resource, ln_type, data_shape, dtype, mesh_axes, fp8_recipe
         )
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(mesh_resource=mesh_resource):
+        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
             x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
             gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
...
@@ -170,8 +214,8 @@ class TestDistributedLayernorm:
                 [x_, gamma_],
                 collective_count_ref,
                 grad_args=(0, 1),
-                metric_fwd_dtype=dtype,
-                metric_bwd_dtype=dtype,
+                metric_fwd_dtype=q_dtype,
+                metric_bwd_dtype=q_dtype,
                 in_shardings=(x_pspec, g_pspec),
                 out_shardings=(None, (x_pspec, g_pspec)),
             )
...
tests/jax/test_distributed_layernorm_mlp.py
View file @ a207db1d

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+from typing import Callable, Sequence, Union, Optional
 import pytest
-from typing import Callable, List, Sequence, Union

 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh, NamedSharding, PartitionSpec

+from utils import (
+    assert_allclose,
+    assert_tree_like_allclose,
+    is_devices_enough,
+    pytest_parametrize_wrapper,
+)

-from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
-from transformer_engine.jax.fp8 import is_fp8_available
+from transformer_engine.common import recipe
+from transformer_engine.jax.quantize import is_fp8_available, ScalingMode
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.flax import LayerNormMLP
-from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp
 from transformer_engine.jax.sharding import (
     HIDDEN_AXES,
     HIDDEN_TP_AXES,
...
@@ -26,17 +32,25 @@ from transformer_engine.jax.sharding import (
     W_JOINED_AXES,
 )
 from transformer_engine.jax.sharding import MeshResource
+from transformer_engine.jax.quantize import QuantizerFactory

-from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough

 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+SUPPORTED_RECIPES = []
+if is_fp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+if is_mxfp8_supported:
+    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

 DTYPES = [jnp.bfloat16, jnp.float16]
-INPUT_SHAPE = [[64, 128, 32]]  # [batch, seqlen, hidden_in]
+INPUT_SHAPE = [[2, 64, 64]]  # [batch, seqlen, hidden_in]

 LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
 DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
 DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
-INTERMEDIATE = 16
+INTERMEDIATE = 64

 # Only test with FSDP and TP as DP is not used
...
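A sketch of how a recipe list like SUPPORTED_RECIPES is consumed; test_recipe_is_selected is a hypothetical consumer, not part of this change. The pytest.param ids ("DelayedScaling", "MXFP8BlockScaling") appear verbatim in the generated test ids, and recipes the hardware does not support are simply absent from the matrix.

import pytest

@pytest.mark.parametrize("fp8_recipe", SUPPORTED_RECIPES)
def test_recipe_is_selected(fp8_recipe):
    # Each collected case receives one supported recipe instance.
    assert fp8_recipe is not None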
@@ -66,13 +80,13 @@ class TestDistributedLayernormMLP:
         x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
         gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
         k1 = jax.random.normal(
-            subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
+            subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
         ) / jnp.sqrt(hidden_in)
         k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
             INTERMEDIATE
         )
         if use_bias:
-            b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
+            b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
             b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
         else:
             b1 = None
...
@@ -86,35 +100,13 @@ class TestDistributedLayernormMLP:
         ln_scale: jnp.ndarray,
         kernel_1: jnp.ndarray,
         kernel_2: jnp.ndarray,
-        bias_1: jnp.ndarray,
-        bias_2: jnp.ndarray,
-        amax_list_1: List[jnp.ndarray],
-        amax_list_2: List[jnp.ndarray],
-        scale_list_1: List[jnp.ndarray],
-        scale_list_2: List[jnp.ndarray],
+        bias_1: Optional[jnp.ndarray],
+        bias_2: Optional[jnp.ndarray],
         layernorm_type: str = "rmsnorm",
         activation_type: Sequence[Union[str, Callable]] = ("gelu",),
         use_bias: bool = True,
         multi_gpus: bool = False,
     ) -> jnp.ndarray:
-        fp8_meta_pkg1 = FP8MetaPackage(
-            amax_list_1[0],
-            scale_list_1[0],
-            amax_list_1[1],
-            scale_list_1[1],
-            amax_list_1[2],
-            scale_list_1[2],
-        )
-        fp8_meta_pkg2 = FP8MetaPackage(
-            amax_list_2[0],
-            scale_list_2[0],
-            amax_list_2[1],
-            scale_list_2[1],
-            amax_list_2[2],
-            scale_list_2[2],
-        )
         if multi_gpus:
             layernorm_input_axes = LAYERNORM_INPUT_AXES
             dot_1_input_axes = DOT_1_INPUT_AXES
...
@@ -124,83 +116,64 @@ class TestDistributedLayernormMLP:
             dot_1_input_axes = None
             dot_2_input_axes = None

+        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
+
         # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
         return jnp.mean(
-            fused_layernorm_fp8_mlp(
+            layernorm_mlp(
                 x,
                 ln_scale,
                 None,
                 [kernel_1, kernel_2],
                 [bias_1, bias_2],
-                [fp8_meta_pkg1, fp8_meta_pkg2],
                 layernorm_type,
-                layernorm_input_axes=layernorm_input_axes,
+                norm_input_axes=layernorm_input_axes,
                 dot_1_input_axes=dot_1_input_axes,
                 dot_2_input_axes=dot_2_input_axes,
                 activation_type=activation_type,
-                use_bias=use_bias,
+                quantizer_sets=quantizer_sets,
             )
         )

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
     def test_layernorm_fp8_mlp_primitive(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
     ):
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
         layernorm_type = "rmsnorm"
-        fp8_amax_list_1 = [
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-        ]
-        fp8_amax_list_2 = [
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-            jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
-        ]
-        fp8_scale_list_1 = [
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-        ]
-        fp8_scale_list_2 = [
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-            jnp.ones((1,), jnp.float32),
-        ]

         inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
             input_shape, activation_type, use_bias, dtype
         )
-        inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
-        static_inputs = [layernorm_type, activation_type, use_bias]
+        static_inputs = [layernorm_type, activation_type]
         value_and_grad_func = jax.value_and_grad(
             self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
         )

         # Single GPU
-        single_jitter = jax.jit(
-            value_and_grad_func,
-            static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
-        )
-        with fp8_autocast(enabled=True):
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            single_jitter = jax.jit(
+                value_and_grad_func,
+                static_argnums=range(len(inputs), len(static_inputs) + len(inputs)),
+            )
             single_fwd, single_grads = single_jitter(*inputs, *static_inputs)

         # Multi GPUs
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
-            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
+        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
+            k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
             k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
             k1_ = jax.device_put(k1, k1_sharding)
             k2_ = jax.device_put(k2, k2_sharding)
             if use_bias:
-                b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
+                b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
                 b1_ = jax.device_put(b1, b1_sharding)
             else:
                 b1_sharding = b1_ = None
...
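The PartitionSpec change for k1 follows directly from the kernel-shape change in generate_inputs: the activation and intermediate dims are now fused into one trailing dim, so one pspec axis disappears. A sketch with illustrative sizes:

from jax.sharding import PartitionSpec

# Sizes here are illustrative, not the test's fixed constants.
hidden_in, n_act, intermediate = 64, 2, 64
old_k1_shape = (hidden_in, n_act, intermediate)   # sharded as ("fsdp", None, "tp")
new_k1_shape = (hidden_in, n_act * intermediate)  # sharded as ("fsdp", "tp")
assert len(new_k1_shape) == len(PartitionSpec("fsdp", "tp"))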
@@ -208,7 +181,7 @@ class TestDistributedLayernormMLP:
             # Position ref for sharding pspec lists
             #   x, gamma, k1, k2, b1,
-            #   b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
+            #   b2
             in_shardings = (
                 None,
                 None,
...
@@ -216,14 +189,10 @@ class TestDistributedLayernormMLP:
                 k2_sharding,
                 b1_sharding,
                 None,
-                None,
-                None,
-                None,
-                None,
             )
             out_shardings = (
                 None,
-                (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
+                (None, None, k1_sharding, k2_sharding, b1_sharding, None),
             )

             multi_jitter = jax.jit(
...
@@ -245,15 +214,42 @@ class TestDistributedLayernormMLP:
                     m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                 )
             else:
+                is_gated = len(activation_type) > 1
+                rtol = None
+                atol = None
+                if is_gated:
+                    if dtype == jnp.bfloat16:
+                        if i == 2:
+                            rtol = 800
+                            atol = 9e-2
+                        if i == 4:
+                            atol = 300
+                            rtol = 1e-1
+                    if dtype == jnp.float16:
+                        if i == 1:  # gamma
+                            rtol = 200
+                            atol = 1e-2
+                        if i == 2:
+                            rtol = 2000
+                            atol = 7e-2
+                        if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():
+                            # bias_1
+                            # Accumulating dbias across a large tensor introduces a larger difference
+                            rtol = 200
+                            atol = 4e-2
+                        if i == 4 and fp8_recipe == recipe.DelayedScaling():
+                            rtol = 2200
+                            atol = 9e-2
                 assert_allclose(
                     multi_grads[i],
                     single_grads[i],
                     dtype=dtype,
+                    rtol=rtol,
+                    atol=atol,
                     err_msg=f"multi_grads[{i}] is not close",
                 )

     def _test_layernorm_mlp(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
     ):
         batch, seqlen, hidden_in = input_shape
         layernorm_type = "rmsnorm"
...
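A generic sketch of the per-gradient tolerance lookup this block implements; the table entries below are placeholders, not the test's actual values. None falls back to assert_allclose's dtype-derived defaults.

# Placeholder data, for illustration of the lookup pattern only.
TOLERANCES = {
    (2, "bfloat16"): {"rtol": 1e-1, "atol": 9e-2},
    (4, "float16"): {"rtol": 2e-1, "atol": 9e-2},
}

def tols_for(grad_index, dtype_name):
    # Missing entries keep the helper's dtype-based defaults (None means default).
    return TOLERANCES.get((grad_index, dtype_name), {"rtol": None, "atol": None})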
@@ -265,7 +261,7 @@ class TestDistributedLayernormMLP:
         init_rngs = {"params": subkeys[1]}

         # Single GPUs
-        with fp8_autocast(enabled=use_fp8):
+        with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
             ln_mlp_single = LayerNormMLP(
                 layernorm_type=layernorm_type,
                 transpose_batch_sequence=False,  # input: [batch, seqlen, hidden]
...
@@ -282,7 +278,9 @@ class TestDistributedLayernormMLP:
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
-        with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
+        with mesh, fp8_autocast(
+            enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+        ):
             ln_mlp_sharded = LayerNormMLP(
                 layernorm_type=layernorm_type,
                 transpose_batch_sequence=False,
...
@@ -310,25 +308,30 @@ class TestDistributedLayernormMLP:
         assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
         assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)

-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
     def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
         self._test_layernorm_mlp(
             mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
         )

-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
-    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
-    @pytest.mark.parametrize("use_bias", [True, False])
-    @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
-    @pytest.mark.parametrize("dtype", DTYPES)
-    def test_layernorm_fp8_mlp_layer(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
-    ):
-        self._test_layernorm_mlp(
-            mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
-        )
+    # TODO: debug
+    # @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    # @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    # @pytest_parametrize_wrapper(
+    #     "activation_type", [("gelu",), ("gelu", "linear")]
+    # )
+    # @pytest_parametrize_wrapper("use_bias", [True, False])
+    # @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    # @pytest_parametrize_wrapper("dtype", DTYPES)
+    # @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    # def test_layernorm_fp8_mlp_layer(
+    #     self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+    # ):
+    #     self._test_layernorm_mlp(
+    #         mesh_config, activation_type, use_bias, input_shape, dtype,
+    #         use_fp8=True, fp8_recipe=fp8_recipe
+    #     )
tests/jax/test_distributed_softmax.py
View file @ a207db1d
...
@@ -3,8 +3,8 @@
 # See LICENSE for license information.

 import warnings
-import pytest
 from functools import partial
+import pytest

 import jax
 import jax.numpy as jnp
...
tests/jax/test_helper.py
View file @ a207db1d
...
@@ -13,13 +13,13 @@ from utils import assert_allclose
 from transformer_engine.common.recipe import DelayedScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

 is_fp8_supported, reason = is_fp8_available()


-class TestFP8Helper(unittest.TestCase):
+class TestQuantizeConfig(unittest.TestCase):

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_initialize(self):
...
@@ -27,30 +27,30 @@ class TestFP8Helper(unittest.TestCase):
         fp8_format = FP8Format.E4M3
         amax_history_len = 10

-        FP8Helper.initialize(
+        QuantizeConfig.initialize(
             margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
         )

         self.assertEqual(
-            FP8Helper.MARGIN,
+            QuantizeConfig.MARGIN,
             margin,
-            f"FP8Helper.MARGIN initialization failed, should be {margin}"
-            f" but got {FP8Helper.MARGIN}.",
+            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
+            f" but got {QuantizeConfig.MARGIN}.",
         )
         self.assertEqual(
-            FP8Helper.FP8_FORMAT,
+            QuantizeConfig.FP8_FORMAT,
             fp8_format,
-            f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {FP8Helper.FP8_FORMAT}.",
+            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
+            f" but got {QuantizeConfig.FP8_FORMAT}.",
         )
         self.assertEqual(
-            FP8Helper.AMAX_HISTORY_LEN,
+            QuantizeConfig.AMAX_HISTORY_LEN,
             amax_history_len,
-            f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
+            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
+            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
         )

-        FP8Helper.finalize()
+        QuantizeConfig.finalize()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):
...
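A minimal sketch of the initialize/finalize lifecycle exercised here; the try/finally is this sketch's own addition (so a failing check cannot leak global FP8 state into later tests), while the calls mirror the diff above.

from transformer_engine.common.recipe import Format
from transformer_engine.jax.quantize import QuantizeConfig

QuantizeConfig.initialize(margin=1.0, fp8_format=Format.E4M3, amax_history_len=10)
try:
    assert QuantizeConfig.AMAX_HISTORY_LEN == 10
finally:
    QuantizeConfig.finalize()  # always restore the default (FP8-disabled) state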
@@ -61,12 +61,12 @@ class TestFP8Helper(unittest.TestCase):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
+        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = FP8Helper.update_collections({"test1": updated_val}, original_state)
+        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)
...
@@ -74,7 +74,7 @@ class TestFP8Helper(unittest.TestCase):
 class TestFP8Functions(unittest.TestCase):

     def _check_defult_state(self):
-        self.assertFalse(FP8Helper.is_fp8_enabled())
+        self.assertFalse(QuantizeConfig.is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)
...
@@ -84,32 +84,32 @@ class TestFP8Functions(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast(self):
-        FP8Helper.finalize()  # Ensure the testing not affect by previous tests.
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
-            self.assertFalse(FP8Helper.is_fp8_enabled())
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
-            self.assertTrue(FP8Helper.is_fp8_enabled())
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_defult_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
-        FP8Helper.finalize()  # Ensure the testing not affect by previous tests.
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
...
@@ -126,7 +126,7 @@ class TestFP8Functions(unittest.TestCase):
         with jax.sharding.Mesh(devices, ("dp", "tp")):
             for sr in mesh_s:
                 with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
-                    self.assertTrue(FP8Helper.is_fp8_enabled())
+                    self.assertTrue(QuantizeConfig.is_fp8_enabled())
                     self._compare_delay_scaling(get_delayed_scaling(), ds)
                     self.assertEqual(sr, global_mesh_resource())
...
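The sharding-resource check above boils down to the following pattern, assuming at least two visible devices; the (1, 2) device split and axis names are illustrative. Entering fp8_autocast under an active Mesh is what registers the MeshResource globally, which the assertions then observe.

import numpy as np
import jax
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.sharding import MeshResource

devices = np.asarray(jax.devices()[:2]).reshape(1, 2)
with jax.sharding.Mesh(devices, ("dp", "tp")):
    mr = MeshResource(dp_resource="dp", tp_resource="tp")
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(), mesh_resource=mr):
        pass  # FP8 modules created here pick up the registered mesh resource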
tests/jax/test_layer.py
View file @ a207db1d
...
@@ -20,11 +20,14 @@ from utils import (
 from utils import DecoderLayer as RefDecoderLayer
 from utils import EncoderLayer as RefEncoderLayer

-from transformer_engine.common.recipe import Format
+from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
-from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
-
-is_fp8_supported, reason = is_fp8_available()
+from transformer_engine.jax.quantize import (
+    QuantizeConfig,
+    ScalingMode,
+    is_fp8_available,
+    update_collections,
+)


 @pytest.fixture(autouse=True, scope="function")
...
@@ -35,12 +38,21 @@ def enable_fused_attn():
         del os.environ["NVTE_FUSED_ATTN"]


+is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+
+QUANTIZE_RECIPES = []
+""" Find supported scaling modes"""
+if is_fp8_supported:
+    QUANTIZE_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+if is_mxfp8_supported:
+    QUANTIZE_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
+
 DATA_SHAPE = [  # (batch, seqlen, emb_dim)
     pytest.param((32, 128, 1024), id="32-128-1024"),
     pytest.param((32, 512, 1024), id="32-512-1024"),
 ]
-DTYPE = [jnp.float32, jnp.bfloat16]
-FP8_FORMATS = [Format.E4M3, Format.HYBRID]
+DTYPE = [jnp.bfloat16]

 _KEY_OF_RESIDUAL_POST_LAYERNORM = "apply_residual_connection_post_layernorm"
 _KEY_OF_OUTPUT_LAYERNORM = "output_layernorm"
...
@@ -80,27 +92,37 @@ BASE_ATTRS = {
 }

 ATTRS = [
+    # attrs0
     {},
+    # attrs1
     {
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
     },
+    # attrs2
     {
         _KEY_OF_ZERO_CENTERED_GAMMA: True,
         _KEY_OF_LAYERNORM_EPS: 1e-2,
     },
+    # attrs3
     {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
+    # attrs4
     {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
+    # attrs5
     {
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
         _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
         _KEY_OF_OUTPUT_LAYERNORM: True,
     },
+    # attrs6
     {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
+    # attrs7
     {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
+    # attrs8
     {
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
         _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
     },
+    # attrs9
     {
         _KEY_OF_SCALE_ATTN_LOGITS: True,
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
...
@@ -109,12 +131,14 @@ ATTRS = [
         _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
         _KEY_OF_USE_BIAS: True,
     },
+    # attrs10
     {
         _KEY_OF_TRANSPOSE_BS: False,
         _KEY_OF_SCALE_ATTN_LOGITS: True,
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
         _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
     },
+    # attrs11
     {
         _KEY_OF_NUM_HEADS: 8,
         _KEY_OF_NUM_GQA_GROUPS: 4,
...
@@ -123,33 +147,7 @@ ATTRS = [
         _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
         _KEY_OF_USE_BIAS: True,
     },
-    {
-        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
-        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
-    },
-    {
-        _KEY_OF_SCALE_ATTN_LOGITS: True,
-        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
-        _KEY_OF_HIDDEN_DROPOUT: 0.8,
-        _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
-        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
-        _KEY_OF_USE_BIAS: True,
-    },
-    {
-        _KEY_OF_TRANSPOSE_BS: False,
-        _KEY_OF_SCALE_ATTN_LOGITS: True,
-        _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
-        _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
-    },
-    {
-        _KEY_OF_NUM_HEADS: 8,
-        _KEY_OF_NUM_GQA_GROUPS: 4,
-        _KEY_OF_TRANSPOSE_BS: False,
-        _KEY_OF_SCALE_ATTN_LOGITS: True,
-        _KEY_OF_LAYERNORM_TYPE: "layernorm",
-        _KEY_OF_MLP_ACTIVATIONS: (("silu",)),
-        _KEY_OF_USE_BIAS: True,
-    },
+    # attrs12
     {
         _KEY_OF_TRANSPOSE_BS: False,
         _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
...
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs13
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_ENABLE_ROPE
:
True
,
_KEY_OF_ROPE_GROUP_METHOD
:
"consecutive"
,
_KEY_OF_USE_BIAS
:
True
,
},
# attrs14
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_LAYERNORM_TYPE
:
"layernorm"
,
...
...
@@ -173,6 +173,7 @@ ATTRS = [
_KEY_OF_USE_BIAS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs15
{
_KEY_OF_TRANSPOSE_BS
:
True
,
_KEY_OF_LAYERNORM_TYPE
:
"rmsnorm"
,
...
...
@@ -180,26 +181,32 @@ ATTRS = [
         _KEY_OF_ROPE_GROUP_METHOD: "alternate",
         _KEY_OF_USE_BIAS: True,
     },
+    # attrs16
     {
         _KEY_OF_HIDDEN_DROPOUT: 0.3,
         _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
         _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
         _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
     },
+    # attrs17
     {
         _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
         _KEY_OF_USE_BIAS: True,
     },
+    # attrs18
     {
         _KEY_OF_RELATIVE_EMBEDDING: False,
         _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
     },
+    # attrs19
     {
         _KEY_OF_ATTENTION_DROPOUT: 0.3,
     },
+    # attrs20
     {
         _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
     },
+    # attrs21
     {
         _KEY_OF_TRANSPOSE_BS: False,
         _KEY_OF_RELATIVE_EMBEDDING: False,
...
@@ -207,6 +214,7 @@ ATTRS = [
         _KEY_OF_WINDOW_SIZE: (64, 0),  # Left size must < DATA_SHAPE seqlen
         _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
     },
+    # attrs22
     {
         _KEY_OF_TRANSPOSE_BS: False,
         _KEY_OF_RELATIVE_EMBEDDING: False,
...
@@ -296,20 +304,24 @@ class BaseRunner:
         ref_params, test_params = self._sync_params(ref_params, test_params)

-        if FP8Helper.is_fp8_enabled():
+        if QuantizeConfig.is_fp8_enabled():
             for _ in range(4):
-                _, tmp_grad = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
+                _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,
                     test_masks,
                     test_params,
                     test_others,
                     test_layer,
                 )
-                _, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
-                test_others = FP8Helper.update_collections(
-                    {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
-                )
-                del tmp_grad, fp8_meta_grad
+                if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+                    _, updated_quantize_meta = flax.core.pop(
+                        updated_state[0], QuantizeConfig.COLLECTION_NAME
+                    )
+                    test_others = update_collections(
+                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
+                    )
+                    del updated_quantize_meta
+                del updated_state

         grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
...
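The warm-up loop above reduces to the following sketch; loss_fn and state are hypothetical stand-ins for the runner's closures. The real loop additionally pops the quantize-meta collection from the grads and merges it back into the mutable test collections on each iteration, but only for delayed tensor scaling, since MXFP8 block scaling keeps no amax history.

import jax

def warm_up(loss_fn, state, steps=4):
    # Run a few differentiated steps so delayed-scaling amax history is
    # populated before the gradients under test are taken.
    for _ in range(steps):
        _, _grads = jax.value_and_grad(loss_fn)(state)
    return state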
@@ -436,29 +448,29 @@ class BaseTester:
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
-        FP8Helper.finalize()  # Ensure FP8 disabled.
+        QuantizeConfig.finalize()  # Ensure FP8 disabled.
         self.runner(attrs).test_forward(data_shape, dtype)

     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
-        FP8Helper.finalize()  # Ensure FP8 disabled.
+        QuantizeConfig.finalize()  # Ensure FP8 disabled.
         self.runner(attrs).test_backward(data_shape, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
-    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
+    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
+    def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
-        FP8Helper.initialize(fp8_format=fp8_format)
+        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
         self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        FP8Helper.finalize()
+        QuantizeConfig.finalize()

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
-    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
+    @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
+    def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
-        FP8Helper.initialize(fp8_format=fp8_format)
+        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
         self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        FP8Helper.finalize()
+        QuantizeConfig.finalize()


 class TestEncoderLayer(BaseTester):
...
tests/jax/test_praxis_layers.py deleted 100644 → 0
View file @ fbee8990
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
from functools import partial
from typing import Dict, Tuple

import flax
import jax
import jax.numpy as jnp
from praxis import pax_fiddle
from praxis.base_layer import WeightInit, DEFAULT_INIT_MUTABLE_LIST
import pytest

from utils import assert_allclose

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
from transformer_engine.jax.flax import LayerNorm as flax_LayerNorm
from transformer_engine.jax.flax import LayerNormMLP as flax_LayerNormMLP
from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAttention
from transformer_engine.jax.flax import DotProductAttention as flax_DotProductAttention
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax
from transformer_engine.jax.praxis import LayerNormLinear, LayerNormMLP, Linear
from transformer_engine.jax.praxis import DotProductAttention, MultiHeadAttention
from transformer_engine.jax.praxis import RelativePositionBiases, TransformerEngineBaseLayer
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType
from transformer_engine.jax.softmax import SoftmaxType

is_fp8_supported, reason = is_fp8_available()

DATA_SHAPE = [(32, 128, 512), (32, 512, 512)]  # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
    for key in ref_fd:
        assert key in test_fd, f"{key} not found in test dict {test_fd}"
        assert isinstance(
            test_fd[key], type(ref_fd[key])
        ), f"The data type is not match between ref and test Dict on {key=}"
        if isinstance(ref_fd[key], Dict):
            compare_dict(ref_fd[key], test_fd[key], rtol, atol)
        else:
            assert_allclose(
                ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
            )
class TestLayer:

    @staticmethod
    def loss(inner_variables, *inner_inputs, module, mean_out=True):
        outs = module.apply(inner_variables, *inner_inputs)
        out = outs
        if isinstance(outs, tuple):
            # The first place of outs is the real output, others
            # are auxiliary values.
            out = outs[0]
        return jnp.mean(out) if mean_out else out

    @staticmethod
    def loss_and_grads(module, variables, *inputs):
        grad_fn = jax.value_and_grad(TestLayer.loss, argnums=(0, 1))
        loss_val, (wgrads, dgrad) = grad_fn(variables, *inputs, module=module)
        return loss_val, wgrads, dgrad

    def input_getter(self, shape, dtype):
        raise NotImplementedError

    def get_layer_name(self):
        raise NotImplementedError

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        raise NotImplementedError

    def sync_variables(self, praxis_variables, flax_variables):
        synced_praxis_variables = praxis_variables

        lyr_name = self.get_layer_name()

        if "params" in flax_variables:
            synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
                flax_variables["params"]
            )

        return synced_praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        synced_praxis_grads = praxis_wgrads

        lyr_name = self.get_layer_name()

        if "params" in synced_praxis_grads:
            synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]

        if FP8Helper.is_fp8_enabled():
            synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
                FP8Helper.FP8_COLLECTION_NAME
            ][lyr_name]["cld"]

        return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)

    def forward_backward_runner(self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08):
        init_key = jax.random.PRNGKey(seed=1234)
        test_inputs = self.input_getter(data_shape, dtype)

        praxis_layer = praxis_p.Instantiate()
        # This is a workaround to correctly enable FP8 meta generation for Praxis.
        # TODO (Ming Huang): To come out a better solution.
        mutable_list = DEFAULT_INIT_MUTABLE_LIST + [FP8Helper.FP8_COLLECTION_NAME]
        praxis_variables = praxis_layer.init(init_key, *test_inputs, mutable=mutable_list)

        flax_layer = flax_cls()
        flax_variables = flax_layer.init(init_key, *test_inputs)
        if "params_axes" in flax_variables:
            flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
        if FP8Helper.is_fp8_enabled():
            flax_variables, _ = flax.core.pop(
                flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
            )

        praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)

        iter_times = 5 if FP8Helper.is_fp8_enabled() else 1

        for _ in range(iter_times):
            praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
                praxis_layer, praxis_variables, *test_inputs
            )
            flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
                flax_layer, flax_variables, *test_inputs
            )
            if FP8Helper.is_fp8_enabled():
                praxis_wgrads.pop("params")
                praxis_variables = update_collections(praxis_wgrads, praxis_variables)
                flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
                flax_variables = update_collections(flax_wgrads, flax_variables)

        praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
            praxis_layer, praxis_variables, *test_inputs
        )
        flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
            flax_layer, flax_variables, *test_inputs
        )
        assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
        assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)

        praxis_wgrads, flax_wgrads = self.sync_wgrads(praxis_wgrads, flax_wgrads)
        compare_dict(praxis_wgrads, flax_wgrads, rtol=rtol, atol=atol)
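loss_and_grads above relies on jax.value_and_grad over two argnums at once; a minimal self-contained illustration of that mechanism (the tiny loss and pytrees below are hypothetical, not from the deleted harness):

import jax
import jax.numpy as jnp

# Differentiating w.r.t. both the variables (argnums=0) and the input
# (argnums=1) yields weight grads and an input grad from a single pass.
def loss(variables, x):
    return jnp.mean(variables["w"] * x)

variables = {"w": jnp.ones((4,))}
x = jnp.arange(4.0)
loss_val, (wgrads, dgrad) = jax.value_and_grad(loss, argnums=(0, 1))(variables, x)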
class LayerNormAttr:
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {LN_TYPE: "layernorm", ZERO_CEN: False},
        {LN_TYPE: "layernorm", ZERO_CEN: True},
        {LN_TYPE: "rmsnorm", ZERO_CEN: False},
    ]


class TestLayerNorm(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "layer_norm"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        layernorm_type = attrs[LayerNormAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormAttr.ZERO_CEN]
        scale_init = None
        bias_init = WeightInit.Constant(0.0)
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNorm,
            name="layer_norm",
            dtype=dtype,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            scale_init=scale_init,
            bias_init=bias_init,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_LayerNorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            scale_init=scale_init,
            bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class FusedSoftmaxAttr:
    SCALE_FACTOR = "scale_factor"
    ST_TYPE = "softmax_type"
    ATTRS = [
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
        {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
    ]


class TestFusedSoftmax(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8)  # Masks

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
        softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]

        praxis_p = pax_fiddle.Config(
            FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
        )
        flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)

        return praxis_p, flax_cls

    def sync_variables(self, praxis_variables, flax_variables):
        return praxis_variables, flax_variables

    def sync_wgrads(self, praxis_wgrads, flax_wgrads):
        return praxis_wgrads, flax_wgrads

    @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
            data_shape[-2] != data_shape[-1]
        ):
            pass  # Skip, due to not support
        else:
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: False},
        {FEATURE: 512, USE_BIAS: True},
        {FEATURE: 1024, USE_BIAS: False},
        {FEATURE: 1024, USE_BIAS: True},
    ]


class TestLinear(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        out_features = attrs[LinearAttr.FEATURE]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            Linear,
            name="linear",
            dtype=dtype,
            out_features=out_features,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            DenseGeneral,
            features=out_features,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormLinearAttr:
    FEATURE = "features"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ATTRS = [
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
        {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
    ]


class TestLayerNormLinear(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "ln_linear"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        out_features = attrs[LayerNormLinearAttr.FEATURE]
        enable_layernorm = attrs[LayerNormLinearAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormLinearAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormLinearAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormLinearAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormLinear,
            name="ln_linear",
            dtype=dtype,
            out_features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            LayerNormDenseGeneral,
            features=out_features,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LayerNormMLPAttr:
    INTERMEDIATE_DIM = "intermediate_dim"
    USE_BIAS = "use_bias"
    ENABLE_LN = "enable_layernorm"
    LN_TYPE = "layernorm_type"
    ZERO_CEN = "zero_centered_gamma"
    ACTIVATION = "activations"
    ATTRS = [
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: True,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
        {
            INTERMEDIATE_DIM: 2048,
            USE_BIAS: False,
            ENABLE_LN: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("silu", "linear"),
        },
    ]


class TestLayerNormMLP(TestLayer):

    def input_getter(self, shape, dtype):
        data_key = jax.random.PRNGKey(seed=1234)
        return (jax.random.normal(data_key, shape, dtype),)

    def get_layer_name(self):
        return "ln_mlp"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
        enable_layernorm = attrs[LayerNormMLPAttr.ENABLE_LN]
        layernorm_type = attrs[LayerNormMLPAttr.LN_TYPE]
        zero_centered_gamma = attrs[LayerNormMLPAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[LayerNormMLPAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        activations = attrs[LayerNormMLPAttr.ACTIVATION]
        axis = -1
        transpose_batch_sequence = False

        praxis_p = pax_fiddle.Config(
            LayerNormMLP,
            name="ln_mlp",
            dtype=dtype,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            transpose_batch_sequence=transpose_batch_sequence,
        )
        flax_cls = partial(
            flax_LayerNormMLP,
            intermediate_dim=intermediate_dim,
            enable_layernorm=enable_layernorm,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            activations=activations,
            intermediate_dropout_rate=0.0,
            axis=axis,
            dtype=dtype,
            transpose_batch_sequence=transpose_batch_sequence,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class TestRelativePositionBias(TestLayer):

    def get_layer_name(self):
        return "relative_position_bias"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        num_buckets = 32
        max_distance = 128
        num_attention_heads = 64
        rb_stddev = (num_attention_heads * num_buckets) ** -0.5
        embedding_init = WeightInit.Gaussian(rb_stddev)

        praxis_p = pax_fiddle.Config(
            RelativePositionBiases,
            name="relative_position_bias",
            dtype=dtype,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=embedding_init,
        )
        flax_cls = partial(
            flax_RelativePositionBiases,
            num_buckets=num_buckets,
            max_distance=max_distance,
            num_attention_heads=num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", embedding_init
            ),
            dtype=dtype,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", [{}])
    def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)

        init_key = jax.random.PRNGKey(seed=1234)
        test_inputs = [(128, 128, True), (128, 128, False)]
        for test_input in test_inputs:
            praxis_layer = praxis_p.Instantiate()
            praxis_variables = praxis_layer.init(init_key, *test_input)

            flax_layer = flax_cls()
            flax_variables = flax_layer.init(init_key, *test_input)
            if "params_axes" in flax_variables:
                flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
            if FP8Helper.is_fp8_enabled():
                flax_variables, _ = flax.core.pop(
                    flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
                )

            praxis_variables, flax_variables = self.sync_variables(
                praxis_variables, flax_variables
            )

            praxis_loss = TestLayer.loss(
                praxis_variables, *test_input, module=praxis_layer, mean_out=False
            )
            flax_loss = TestLayer.loss(
                flax_variables, *test_input, module=flax_layer, mean_out=False
            )

            assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
    ATTN_MASK_TYPE = "attn_mask_type"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    SCALE_FACTOR = "scale_factor"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding_causal", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: True, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding", TRANSPOSE_BS: False, SCALE_FACTOR: 0.125},
        {ATTN_MASK_TYPE: "padding_causal", TRANSPOSE_BS: False, SCALE_FACTOR: 2.0},
        {ATTN_MASK_TYPE: "causal", TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
        {ATTN_MASK_TYPE: "no_mask", TRANSPOSE_BS: False, SCALE_FACTOR: 1.0},
        {
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
            SCALE_FACTOR: 1.0,
            WINDOW_SIZE: (64, 0),  # Left size must <= S in DATA_SHAPE
        },
    ]


class TestDotProductAttn(TestLayer):

    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s, *_ = shape
        if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
            mask,
        ]

    def get_layer_name(self):
        return "dot_product_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = num_attention_heads
        attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
        transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
        window_size = attrs.get(DotProductAttnAttr.WINDOW_SIZE, None)

        praxis_p = pax_fiddle.Config(
            DotProductAttention,
            name="mha",
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_DotProductAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            attn_mask_type=attn_mask_type,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )

        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class MultiHeadAttnAttr:
    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ATTN_MASK_TYPE = "attn_mask_type"
    ZERO_CEN = "zero_centered_gamma"
    NUM_ATTN_HEADS = "num_attention_heads"
    NUM_GQA_GROUPS = "num_gqa_groups"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            NUM_ATTN_HEADS: 8,
            NUM_GQA_GROUPS: 4,
            ATTN_MASK_TYPE: "causal",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "padding",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            ATTN_MASK_TYPE: "causal",
            LORA_SCOPE: "all",
            TRANSPOSE_BS: True,
            WINDOW_SIZE: (64, 0),  # Left size must be <= S in DATA_SHAPE
        },
    ]
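Each ATTRS entry is a plain dict of constructor keywords; the optional keys (NUM_GQA_GROUPS, LORA_SCOPE, WINDOW_SIZE) are read back with dict.get, so cases that omit them fall through to defaults. A small sketch of that lookup pattern with illustrative values:

case = {
    "use_bias": True,
    "layernorm_type": "rmsnorm",
    "attn_mask_type": "causal",
    "transpose_batch_sequence": False,
}

# Optional keys fall back when a case omits them, mirroring
# generate_praxis_p_and_flax_cls below.
window_size = case.get("window_size", None)                 # None: global attention
lora_scope = case.get("low_rank_adaptation_scope", "none")  # LoRA disabled
print(window_size, lora_scope)  # None none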
class TestMultiHeadAttn(TestLayer):
    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[MultiHeadAttnAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(
                partial(jax.random.normal, shape=shape, dtype=dtype),
                [q_key, kv_key],
            ),
            mask,
        ]

    def get_layer_name(self):
        return "multi_head_attn"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        head_dim = 64
        num_attention_heads = 16
        num_gqa_groups = (
            attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
            if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
            else None
        )
        layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
        zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[MultiHeadAttnAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        input_layernorm = False
        return_layernorm_output = False
        attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
        enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
        fuse_qkv_params = True
        transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
        scale_attn_logits = False
        scaled_query_init = True
        float32_logits = False
        window_size = attrs.get(MultiHeadAttnAttr.WINDOW_SIZE, None)

        praxis_p = pax_fiddle.Config(
            MultiHeadAttention,
            name="mha",
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            params_init=kernel_init,
            use_bias=use_bias,
            bias_init=bias_init,
            return_layernorm_output=return_layernorm_output,
            input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init,
            float32_logits=float32_logits,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_MultiHeadAttention,
            dtype=dtype,
            head_dim=head_dim,
            num_attention_heads=num_attention_heads,
            num_gqa_groups=num_gqa_groups,
            layernorm_type=layernorm_type,
            zero_centered_gamma=zero_centered_gamma,
            kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", kernel_init),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            return_layernorm_output=return_layernorm_output,
            input_layernorm=input_layernorm,
            attn_mask_type=attn_mask_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            fuse_qkv_params=fuse_qkv_params,
            transpose_batch_sequence=transpose_batch_sequence,
            scale_attn_logits=scale_attn_logits,
            scaled_query_init=scaled_query_init,
            float32_logits=float32_logits,
            window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
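The FP8 variant differs from test_forward_backward only in wrapping construction and the runner in fp8_autocast with a DelayedScaling recipe. A sketch of that wrapper, where build_and_run is a hypothetical stand-in for the forward/backward comparison (the imports match what this test module already uses):

from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast

def run_with_fp8(build_and_run, fp8_format=Format.HYBRID):
    # DelayedScaling tracks per-tensor amax history to derive FP8 scaling factors.
    ds = DelayedScaling(fp8_format=fp8_format)
    with fp8_autocast(enabled=True, fp8_recipe=ds):
        return build_and_run()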
class TransformerLayerAttr:
    USE_BIAS = "use_bias"
    LN_TYPE = "layernorm_type"
    ACTIVATION = "activations"
    LYR_TYPE = "layer_type"
    ZERO_CEN = "zero_centered_gamma"
    TRANSPOSE_BS = "transpose_batch_sequence"
    ENABLE_ROPE = "enable_rotary_pos_emb"
    ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
    LORA_SCOPE = "low_rank_adaptation_scope"
    WINDOW_SIZE = "window_size"
    ATTRS = [
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: True,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "rmsnorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu", "linear"),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "alternate",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: True,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: True,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("gelu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            LORA_SCOPE: "all",
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.ENCODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            WINDOW_SIZE: (64, 0),  # Left size must be <= S in DATA_SHAPE
        },
        {
            USE_BIAS: True,
            LN_TYPE: "layernorm",
            ZERO_CEN: False,
            ACTIVATION: ("relu",),
            LYR_TYPE: TransformerLayerType.DECODER,
            ENABLE_ROPE: False,
            ROPE_GROUP_METHOD: "consecutive",
            TRANSPOSE_BS: False,
            WINDOW_SIZE: (64, 0),  # Left size must be <= S in DATA_SHAPE
        },
    ]
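The WINDOW_SIZE: (64, 0) cases assume the usual (left, right) sliding-window convention: each query sees at most 64 positions of left context and none to the right, which is why the left size must stay within the sequence length S. A toy visibility predicate for that convention (our own illustration, not the library's attention kernel):

def in_window(q_idx: int, kv_idx: int, left: int = 64, right: int = 0) -> bool:
    # kv_idx is visible from q_idx iff it lies within [q_idx - left, q_idx + right].
    return -right <= q_idx - kv_idx <= left

print(in_window(100, 40), in_window(100, 30))  # True False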
class TestTransformer(TestLayer):
    def input_getter(self, shape, dtype):
        key = jax.random.PRNGKey(seed=1234)
        q_key, kv_key = jax.random.split(key, 2)
        b, s, *_ = shape
        if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
            shape = (shape[1], shape[0]) + shape[2:]
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        return [
            *map(
                partial(jax.random.normal, shape=shape, dtype=dtype),
                [q_key, kv_key],
            ),
            mask,
            mask,
        ]

    def get_layer_name(self):
        return "transformerlayer"

    def generate_praxis_p_and_flax_cls(self, dtype, attrs):
        hidden_size = 512
        mlp_hidden_size = 2048
        num_attention_heads = 8
        layernorm_type = attrs[TransformerLayerAttr.LN_TYPE]
        hidden_dropout = 0.0
        attention_dropout = 0.0
        intermediate_dropout = 0.0
        mlp_activations = attrs[TransformerLayerAttr.ACTIVATION]
        kernel_init = WeightInit.Gaussian(1.0)
        use_bias = attrs[TransformerLayerAttr.USE_BIAS]
        bias_init = WeightInit.Constant(0.0)
        layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
        enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
        rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
        low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
        enable_relative_embedding = True
        relative_embedding = pax_fiddle.Config(
            RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
        )
        drop_path = 0.0
        transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
        window_size = attrs.get(TransformerLayerAttr.WINDOW_SIZE, None)

        rel_embedding_init = RelativePositionBiases.generate_embedding_init(
            relative_embedding.embedding_init,
            relative_embedding.num_attention_heads,
            relative_embedding.num_buckets,
        )
        relative_embedding_flax_module = flax_RelativePositionBiases(
            num_buckets=relative_embedding.num_buckets,
            max_distance=relative_embedding.max_distance,
            num_attention_heads=relative_embedding.num_attention_heads,
            embedding_init=TransformerEngineBaseLayer.generate_params_init(
                "rel_embedding", rel_embedding_init
            ),
            embedding_axes=relative_embedding.embedding_axes,
            dtype=relative_embedding.dtype,
        )

        praxis_p = pax_fiddle.Config(
            TransformerLayer,
            name="transformer_layer",
            params_init=kernel_init,
            dtype=dtype,
            hidden_size=hidden_size,
            mlp_hidden_size=mlp_hidden_size,
            num_attention_heads=num_attention_heads,
            layernorm_type=layernorm_type,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            intermediate_dropout=intermediate_dropout,
            mlp_activations=mlp_activations,
            use_bias=use_bias,
            bias_init=bias_init,
            layer_type=layer_type,
            enable_relative_embedding=enable_relative_embedding,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            relative_embedding=relative_embedding,
            drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        flax_cls = partial(
            flax_TransformerLayer,
            dtype=dtype,
            hidden_size=hidden_size,
            mlp_hidden_size=mlp_hidden_size,
            num_attention_heads=num_attention_heads,
            layernorm_type=layernorm_type,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            intermediate_dropout=intermediate_dropout,
            mlp_activations=mlp_activations,
            mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mha_kernel", kernel_init
            ),
            mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
                "mlp_kernel", kernel_init
            ),
            use_bias=use_bias,
            bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
            layer_type=layer_type,
            enable_rotary_pos_emb=enable_rotary_pos_emb,
            rotary_pos_emb_group_method=rotary_pos_emb_group_method,
            enable_relative_embedding=enable_relative_embedding,
            relative_embedding=relative_embedding_flax_module,
            low_rank_adaptation_scope=low_rank_adaptation_scope,
            drop_path=drop_path,
            transpose_batch_sequence=transpose_batch_sequence,
            window_size=window_size,
        )
        return praxis_p, flax_cls

    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
        self.attrs = attrs
        praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
        self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    @pytest.mark.parametrize("data_shape", DATA_SHAPE)
    @pytest.mark.parametrize("dtype", DTYPE)
    @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
    @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
    def test_forward_backward_fp8(
        self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
    ):
        self.attrs = attrs
        ds = DelayedScaling(fp8_format=fp8_format)
        with fp8_autocast(enabled=True, fp8_recipe=ds):
            praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
            self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
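Because the stacked parametrize marks multiply out into a large case grid, narrowing a run with pytest's -k filter is the practical way to reproduce a single failure. A hypothetical invocation, with the test path matching this repository's layout:

import pytest

if __name__ == "__main__":
    # Equivalent to: pytest tests/jax/test_praxis_layers.py -k "TestTransformer and not fp8" -x
    pytest.main(["tests/jax/test_praxis_layers.py", "-k", "TestTransformer and not fp8", "-x"])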