OpenDAS / TransformerEngine, commit f8c2af4c

Commit f8c2af4c, authored May 21, 2025 by yuguo.

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211. Showing 20 changed files with 548 additions and 362 deletions (+548 -362).
Changed files:

tests/cpp/operator/test_normalization.cu                         +10  -10
tests/cpp/operator/test_normalization_mxfp8.cu                   +7   -7
tests/cpp/operator/test_qdq.cu                                   +4   -4
tests/cpp/operator/test_transpose.cu                             +2   -2
tests/cpp/test_common.cu                                         +6   -11
tests/cpp/test_common.h                                          +2   -0
tests/jax/test_custom_call_compute.py                            +138 -35
tests/jax/test_distributed_layernorm.py                          +3   -0
tests/jax/test_distributed_layernorm_mlp.py                      +1   -27
tests/jax/test_helper.py                                         +78  -46
tests/jax/utils.py                                               +49  -41
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py      +27  -14
tests/pytorch/distributed/run_gemm_with_overlap.py               +79  -29
tests/pytorch/distributed/run_layer_with_overlap.py              +9   -2
tests/pytorch/distributed/run_numerics.py                        +1   -1
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py     +7   -2
tests/pytorch/distributed/test_comm_gemm_overlap.py              +44  -73
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py   +65  -37
tests/pytorch/fused_attn/run_fused_attn_with_cp.py               +6   -2
tests/pytorch/fused_attn/test_fused_attn.py                      +10  -19
tests/cpp/operator/test_normalization.cu

@@ -49,16 +49,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }
-  Tensor input("input", { N, H }, itype);
-  Tensor z("z", { N, H }, otype);
-  Tensor gamma("gamma", { H }, wtype);
-  Tensor beta("beta", { H }, wtype);
-  Tensor mu("mu", { N }, DType::kFloat32);
-  Tensor rsigma("rsigma", { N }, DType::kFloat32);
-  Tensor dz("dz", { N, H }, wtype);
-  Tensor dx("dx", { N, H }, itype);
-  Tensor dgamma("dgamma", { H }, wtype);
-  Tensor dbeta("dbeta", { H }, wtype);
+  Tensor input("input", std::vector<size_t>{ N, H }, itype);
+  Tensor z("z", std::vector<size_t>{ N, H }, otype);
+  Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
+  Tensor beta("beta", std::vector<size_t>{ H }, wtype);
+  Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
+  Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
+  Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
+  Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
+  Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
+  Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
   Tensor workspace_fwd, workspace_bwd;

   fillUniform(&input);
tests/cpp/operator/test_normalization_mxfp8.cu

@@ -116,12 +116,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   DType wtype = TypeInfo<WeightType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;
-  Tensor input("input", { N, H }, itype);
-  Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
-  Tensor gamma("gamma", { H }, wtype);
-  Tensor beta("beta", { H }, wtype);
-  Tensor mu("mu", { N }, DType::kFloat32);
-  Tensor rsigma("rsigma", { N }, DType::kFloat32);
+  Tensor input("input", std::vector<size_t>{ N, H }, itype);
+  Tensor z("z", std::vector<size_t>{ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
+  Tensor gamma("gamma", std::vector<size_t>{ H }, wtype);
+  Tensor beta("beta", std::vector<size_t>{ H }, wtype);
+  Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
+  Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
   Tensor workspace;

@@ -164,7 +164,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     nvte_enable_zero_centered_gamma_in_weight_dtype(false);
   }
-  Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true);
+  Tensor dequantized_output("dequantized_output", std::vector<size_t>{ N, H }, DType::kFloat32, true, true);
   dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
tests/cpp/operator/test_qdq.cu

@@ -58,8 +58,8 @@ void performTestQ(const size_t N) {
   DType itype = TypeInfo<InputType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;
-  Tensor input("input", { N }, itype);
-  Tensor output("output", { N }, otype);
+  Tensor input("input", std::vector<size_t>{ N }, itype);
+  Tensor output("output", std::vector<size_t>{ N }, otype);
   std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);

@@ -89,8 +89,8 @@ void performTestDQ(const size_t N) {
   DType itype = TypeInfo<InputType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;
-  Tensor input("input", { N }, itype);
-  Tensor output("output", { N }, otype);
+  Tensor input("input", std::vector<size_t>{ N }, itype);
+  Tensor output("output", std::vector<size_t>{ N }, otype);
   std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
tests/cpp/operator/test_transpose.cu

@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
   DType dtype = TypeInfo<Type>::dtype;
-  Tensor input("input", { N, H }, dtype);
-  Tensor output("output", { H, N }, dtype);
+  Tensor input("input", std::vector<size_t>{ N, H }, dtype);
+  Tensor output("output", std::vector<size_t>{ H, N }, dtype);
   std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
tests/cpp/test_common.cu

@@ -783,8 +783,6 @@ void fillUniform(Tensor *t) {
 template <typename InputEncoding, InputsFillCase Case>
 void fillCase_special(Tensor *t) {
   const size_t size = product(t->rowwise_shape());
-  const size_t rows = t->rowwise_shape().data[0];
-  const size_t cols = t->rowwise_shape().data[1];
   if constexpr (Case == InputsFillCase::zeros) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {

@@ -804,9 +802,7 @@ void fillCase_special(Tensor *t) {
     std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
       InputType *data = t->rowwise_cpu_dptr<InputType>();
-      for (size_t i = 0; i < rows; ++i) {
-        for (size_t j = 0; j < cols; ++j) {
-          const size_t idx = i * cols + j;
+      for (size_t idx = 0; idx < size; ++idx) {
         const bool is_negative = (dis_sign(t->gen()) < 0.0);
         double val = dis(t->gen());
         if (is_negative) {

@@ -814,7 +810,6 @@ void fillCase_special(Tensor *t) {
         }
         data[idx] = static_cast<InputType>(val);
       }
-      }
     });
   }
   t->set_scale_inv(1.0);
tests/cpp/test_common.h

@@ -52,6 +52,7 @@ struct BytesToType<8> {
 };

 using byte = uint8_t;
+using int16 = int16_t;
 using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;

@@ -70,6 +71,7 @@ using fp8e8m0 = uint8_t;
 template <typename T>
 struct TypeInfo {
   using types = std::tuple<byte,
+                           int16,
                            int32,
                            int64,
                            fp32,
tests/jax/test_custom_call_compute.py

@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
 from jax import jit, value_and_grad
+from functools import reduce

@@ -18,11 +19,16 @@ from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp
 from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
-from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
+from transformer_engine.jax.cpp_extensions.normalization import (
+    _jax_layernorm,
+    _jax_rmsnorm,
+    is_norm_zero_centered_gamma_in_weight_dtype,
+)
 from transformer_engine.jax.cpp_extensions.quantization import (
     _jax_quantize,
     _jax_quantize_dbias,
 )
+from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
     DelayedScaleQuantizer,

@@ -33,7 +39,7 @@ from transformer_engine.jax.quantize import (
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.dense import dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
 from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x

@@ -54,6 +60,7 @@ supported_scaling_modes = []
 """ Find supported scaling modes"""
 if is_fp8_supported:
     supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
+    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
 if is_mxfp8_supported:
     supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)

@@ -71,8 +78,19 @@ def is_shape_supported_by_mxfp8(input_shape):
 def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
     if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
-        assert_allclose(a.data, b.data)
         assert a.scaling_mode == b.scaling_mode
         assert a.scale_inv.dtype == b.scale_inv.dtype
+        if a.scaling_mode.is_tensor_scaling():
+            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
+            # to an input dtype which reduces precision compared to everything in fp32
+            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
+        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            # Compare MXFP8 scales as uint8
+            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
+        else:
+            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
+        assert_allclose(a.data, b.data)
     elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
         assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
         assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)

@@ -159,7 +177,12 @@ class TestActivation:
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_grad_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, scaling_mode
+    ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)
         x = jnp.repeat(x, len(activation_type), axis=-2)

@@ -170,7 +193,7 @@ class TestActivation:
         )
         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=QuantizeLayout.ROWWISE,
         )

@@ -188,8 +211,11 @@ class TestActivation:
     @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE])
-    def test_act_forward_with_delayed_scaling_fp8(
-        self, random_inputs, activation_type, output_type, q_layout
-    ):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_forward_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
+    ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)

@@ -198,7 +224,7 @@ class TestActivation:
         te_quantizer, jax_quantizer = QuantizerFactory.create(
             n_quantizers=2,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=q_layout,
         )
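The hunks above all make the same move: a delayed-scaling-only test becomes a tensor-scaling test parametrized over both recipes. A minimal sketch of the pattern, assuming these names import from transformer_engine.jax.quantize as they do at the top of this file, and assuming a created Quantizer exposes the scaling_mode it was built with:

    import jax.numpy as jnp
    import pytest
    from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode

    # Run the same test body under both tensor-scaling recipes.
    @pytest.mark.parametrize(
        "scaling_mode",
        [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING],
    )
    def test_create_quantizer(scaling_mode):
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=jnp.float8_e4m3fn,
            q_layout=QuantizeLayout.ROWWISE,
        )
        assert quantizer.scaling_mode == scaling_mode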
@@ -335,8 +361,20 @@ class TestNorm:
     @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE])
-    def test_norm_grad_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
-    ):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_grad_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
+    ):
         """
         Test transformer_engine.jax.layernorm.layernorm

@@ -345,9 +383,7 @@ class TestNorm:
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
-            q_dtype=out_dtype,
-            q_layout=q_layout,
+            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
         )
         self._test_norm_grad(
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer

@@ -395,7 +431,41 @@ class TestNorm:
         )
         ref_mu = None

+        precise_comparison = True
+        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            # Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead
+            # do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision
+            precise_comparison = False
+        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
+            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
+            # for zero-centered gamma always
+            precise_comparison = False
+        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
+            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
+            # and writes intermediate results into the input dtype, which will slightly reduce precision
+            # if the input dtype is not float32
+            precise_comparison = False
+
+        if precise_comparison:
             assert_bitwise_scaled_tensors(output, ref_out)
+        else:
+            if isinstance(ref_out, ScaledTensor1x):
+                assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
+            elif isinstance(ref_out, ScaledTensor2x):
+                assert_allclose(
+                    output.rowwise_tensor.dequantize(),
+                    ref_out.rowwise_tensor.dequantize(),
+                    dtype=out_dtype,
+                )
+                assert_allclose(
+                    output.colwise_tensor.dequantize(),
+                    ref_out.colwise_tensor.dequantize(),
+                    dtype=out_dtype,
+                )
+            else:
+                pytest.fail("Unsupported output type")

         assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
         if norm_type == "layernorm":
             assert_allclose(mu, ref_mu, dtype=inp_dtype)

@@ -406,8 +476,20 @@ class TestNorm:
     @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE])
-    def test_norm_forward_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
-    ):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_forward_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
+    ):
         if norm_type == "rmsnorm" and zero_centered_gamma is True:
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

@@ -420,7 +502,7 @@ class TestNorm:
             epsilon=epsilon,
             inp_dtype=inp_dtype,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_layout=q_layout,
         )

@@ -447,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = {
     "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
 }

-ALL_QUANTIZE_TEST_SHAPES = [
-    (32, 64),
-    (2, 64, 32),
-]
+ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
+    ((32, 64), -1),
+    ((2, 64, 32), -1),
+    ((2, 64, 32), -2),
+    ((32, 256, 128), -1),
+    ((32, 256, 128), -2),
+    ((64, 32, 32, 256), -1),
+    ((64, 32, 32, 256), -2),
+    ((64, 32, 32, 256), -3),
+]

-QUANTIZE_TEST_SHAPES = {
+QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
     "L0": [
-        (32, 256, 128),
-        (64, 32, 32, 256),
+        ((32, 64), -1),
+        ((2, 64, 32), -1),
+        ((2, 64, 32), -2),
     ],
-    "L2": ALL_QUANTIZE_TEST_SHAPES,
+    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
 }

 QUANTIZATION_INPUT_DTYPE = {
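Pairing each shape with its flatten axis avoids the invalid combinations a plain cross product produces (for example, a 2D shape split at axis -2 leaves no leading axis). A small self-contained sketch of the paired parametrization with plain pytest:

    import pytest

    SHAPES_AND_FLATTEN_AXES = [((32, 64), -1), ((2, 64, 32), -1), ((2, 64, 32), -2)]

    @pytest.mark.parametrize("input_shape,flatten_axis", SHAPES_AND_FLATTEN_AXES)
    def test_split_is_valid(input_shape, flatten_axis):
        # Each pair keeps at least one axis on either side of the split.
        lead, trail = input_shape[:flatten_axis], input_shape[flatten_axis:]
        assert len(lead) >= 1 and len(trail) >= 1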
@@ -469,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = {
 @pytest.mark.skipif(not is_fp8_supported, reason=reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
+@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
 @pytest_parametrize_wrapper(
     "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
 )

@@ -524,12 +612,11 @@ class TestFusedQuantize:
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
+    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
     @pytest_parametrize_wrapper("q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE])
-    @pytest_parametrize_wrapper("flatten_axis", [-1, -2])
     def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis):

@@ -538,6 +625,12 @@ class TestFusedQuantize:
         ):
             pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

+        if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
+            pytest.skip(
+                f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
+                " must be at least one axis on either side of the flatten_axis split."
+            )
+
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)

@@ -630,16 +723,19 @@ class TestFusedQuantize:
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
     @pytest_parametrize_wrapper("is_dbias", [True, False])
     @pytest_parametrize_wrapper(
-        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
+        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_quantize_dact_dbias_delayed_scaling(
-        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
-    ):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_quantize_dact_dbias_tensor_scaling(
+        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
+    ):
         self._test_quantize_dact_dbias(
             in_dtype=in_dtype,
             input_shape=input_shape,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             activation_type=activation_type,
             is_dbias=is_dbias,
             q_layout=q_layout,

@@ -830,7 +926,10 @@ class TestFusedDense:
         Test layernorm_dense VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")

         # zero_centered_gamma is already tested in TestNorm

@@ -916,7 +1015,10 @@ class TestFusedDense:
         Test layernorm_mlp VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")

         # zero_centered_gamma is already tested in TestNorm

@@ -1052,7 +1154,7 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]

+"""
 @pytest_parametrize_wrapper(
     "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
 )

@@ -1267,3 +1369,4 @@ class TestGroupedDense:
         assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
         assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
         assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
+"""
tests/jax/test_distributed_layernorm.py

@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
         other_bytes = 0
         if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
             other_bytes = 384  # required for small scale shapes that require padding
+        if fp8_recipe == recipe.Float8CurrentScaling():
+            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction

         return generate_collectives_count(
             allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )
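The added branch accounts for the amax reduction that Float8CurrentScaling performs across data-parallel ranks: one scalar of the compute dtype per cast. A toy sketch of that byte bookkeeping (hypothetical helper, not part of the test suite):

    import numpy as np

    def expected_allreduce_bytes(base_bytes, jax_dtype, uses_current_scaling):
        # Float8CurrentScaling adds one scalar amax reduction, costing one
        # element of the compute dtype in the allreduce byte total.
        extra = np.dtype(jax_dtype).itemsize if uses_current_scaling else 0
        return base_bytes + extra

    assert expected_allreduce_bytes(1024, np.float32, True) == 1028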
tests/jax/test_distributed_layernorm_mlp.py

@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
                     m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                 )
         else:
-            is_gated = len(activation_type) > 1
-            rtol = None
-            atol = None
-            if is_gated:
-                if dtype == jnp.bfloat16:
-                    if i == 2:
-                        rtol = 800
-                        atol = 9e-2
-                    if i == 4:
-                        atol = 300
-                        rtol = 1e-1
-                if dtype == jnp.float16:
-                    if i == 1:  # gamma
-                        rtol = 200
-                        atol = 1e-2
-                    if i == 2:
-                        rtol = 2000
-                        atol = 7e-2
-                    if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
-                        # Accumulating dbias across a large tensor introduces a larger difference
-                        rtol = 200
-                        atol = 4e-2
-                    if i == 4 and fp8_recipe == recipe.DelayedScaling():
-                        rtol = 2200
-                        atol = 9e-2
             assert_allclose(
                 multi_grads[i],
                 single_grads[i],
                 dtype=dtype,
-                rtol=rtol,
-                atol=atol,
                 err_msg=f"multi_grads[{i}] is not close",
             )
tests/jax/test_helper.py

@@ -10,47 +10,22 @@ import jax.numpy as jnp
 import numpy as np

 from utils import assert_allclose
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import (
+    QuantizeConfig,
+    is_fp8_available,
+    ScalingMode,
+    update_collections,
+)
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

-class TestQuantizeConfig(unittest.TestCase):
-    @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_initialize(self):
-        margin = 5.0
-        fp8_format = FP8Format.E4M3
-        amax_history_len = 10
-        QuantizeConfig.initialize(
-            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
-        )
-        self.assertEqual(
-            QuantizeConfig.MARGIN,
-            margin,
-            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
-            f" but got {QuantizeConfig.MARGIN}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.FP8_FORMAT,
-            fp8_format,
-            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {QuantizeConfig.FP8_FORMAT}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.AMAX_HISTORY_LEN,
-            amax_history_len,
-            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
-        )
-        QuantizeConfig.finalize()

+class TestHelper(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):

@@ -61,19 +36,19 @@ class TestQuantizeConfig(unittest.TestCase):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

 class TestFP8Functions(unittest.TestCase):
-    def _check_defult_state(self):
+    def _check_default_state(self):
         self.assertFalse(QuantizeConfig.is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):

@@ -82,35 +57,92 @@ class TestFP8Functions(unittest.TestCase):
         self.assertTrue(ref.amax_history_len == test.amax_history_len)
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

+    def _compare_current_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+
+    def _compare_mxfp8_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast(self):
+    def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
             self.assertFalse(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()

+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_current_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_default_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+            self._check_default_state()
+
+        self._check_default_state()
+
+        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_default_state()
+
+        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_default_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_default_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+            self._check_default_state()
+
+        self._check_default_state()
+
+        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_default_state()
+
+        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_default_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

@@ -130,4 +162,4 @@ class TestFP8Functions(unittest.TestCase):
         self._compare_delay_scaling(get_delayed_scaling(), ds)
         self.assertEqual(sr, global_mesh_resource())

-        self._check_defult_state()
+        self._check_default_state()
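All three autocast tests share one shape: enter fp8_autocast with a recipe, check that QuantizeConfig reflects it, then check that the default state is restored on exit. A condensed sketch using only names imported in this file (like the tests themselves, it needs an FP8-capable GPU to actually run):

    from transformer_engine.common.recipe import DelayedScaling
    from transformer_engine.jax import fp8_autocast
    from transformer_engine.jax.quantize import QuantizeConfig

    assert not QuantizeConfig.is_fp8_enabled()
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(margin=5.0, amax_history_len=1)):
        # Inside the context the recipe drives the global quantize config.
        assert QuantizeConfig.is_fp8_enabled()
    assert not QuantizeConfig.is_fp8_enabled()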
tests/jax/utils.py

@@ -13,7 +13,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn

@@ -97,16 +96,16 @@ def combine_biases(*masks: Optional[Array]):
     return mask

-def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
+def get_parameters_for_test_level(param_dict: dict):
     """
     Takes an input dictionary of parameters keyed by test type "L0", etc.
-    Returns a list of pytest parameters to be used in a parameterized test for the current test type
+    Returns the parameters for the test level specified in the environment variable
     """
     DEFAULT_TEST_LEVEL = "L0"
     test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)

     if test_level not in param_dict:
         raise ValueError("Unsupported test level")

-    return values_to_named_params(param_dict[test_level], id_prefix)
+    return param_dict[test_level]

 def value_to_test_name_str(value):

@@ -139,14 +138,18 @@ def pytest_parametrize_wrapper(param_name, param_values):
     A wrapper for pytest.mark.parametrize to allow for automatic
     naming of tests based on the parameter values.
     """
-    id_prefix = param_name
     if isinstance(param_values, dict):
-        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
-    elif "," not in param_name:
-        param_values = values_to_named_params(param_values, id_prefix=id_prefix)
+        # If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
+        # unwrap the selected level before proceeding.
+        param_values = get_parameters_for_test_level(param_values)

-    # Currently comma separated parameters in one parametrize call aren't supported for automatic naming
-    # and will just be passed through with default pytest names
+    if "," not in param_name:
+        # Multi-parameterize annotations are not supported in this wrapper
+        # and are just a passthrough to default pytest.mark.parametrize.
+        # E.g. @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2)))
+        # will be passed through to pytest.mark.parametrize as-is without pytest.param ids.
+        param_values = values_to_named_params(param_values, id_prefix=param_name)

     def decorator(func):
         return pytest.mark.parametrize(param_name, param_values)(func)
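Call sites are unchanged by this refactor; a dict value now just selects the NVTE_JAX_UNITTEST_LEVEL subset inside the wrapper before ids are generated. A sketch with a hypothetical test, using the pytest_parametrize_wrapper defined above:

    # Values keyed by test level; get_parameters_for_test_level() picks the
    # subset for NVTE_JAX_UNITTEST_LEVEL (default "L0") inside the wrapper.
    SHAPES = {
        "L0": [(32, 64)],
        "L2": [(32, 64), (2, 64, 32)],
    }

    @pytest_parametrize_wrapper("shape", SHAPES)
    def test_shape_rank(shape):
        assert len(shape) >= 2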
@@ -312,16 +315,22 @@ class DenseGeneral(nn.Module):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_param_shape,
+            self.dtype,
         )
         kernel = jnp.asarray(kernel, input_dtype)
         kernel = jnp.reshape(kernel, kernel_shape)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                self.features,
+                self.dtype,
             )
             bias = bias.astype(input_dtype)
         else:
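These hunks migrate from flax.linen.partitioning.param_with_axes to the core Flax idiom, nn.Module.param with nn.with_logical_partitioning wrapping the initializer. A minimal standalone module in the new style (hypothetical, for illustration; the axis names are placeholders):

    import flax.linen as nn
    import jax.numpy as jnp

    class TinyDense(nn.Module):
        features: int

        @nn.compact
        def __call__(self, x):
            # nn.with_logical_partitioning attaches the logical axis names
            # ("embed", "mlp") that param_with_axes used to take via axes=.
            kernel = self.param(
                "kernel",
                nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
                (x.shape[-1], self.features),
                jnp.float32,
            )
            return x @ jnp.asarray(kernel, x.dtype)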
@@ -418,9 +427,9 @@ class MlpBlock(nn.Module):
         )

         # Broadcast along length.
         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))

         output = DenseGeneral(
             inputs.shape[-1],
             dtype=self.dtype,

@@ -684,21 +693,13 @@ class MultiHeadAttention(nn.Module):
             value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

         if self.transpose_batch_sequence:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("length", "batch", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("length", "batch", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
         else:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("batch", "length", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("batch", "length", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

         if decode:
             # Detect if we're initializing by absence of existing cache data.

@@ -805,9 +806,9 @@ class MultiHeadAttention(nn.Module):
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.

@@ -853,8 +854,11 @@ class LayerNorm(nn.Module):
         input_dtype = x.dtype
         features = x.shape[-1]
-        scale = nn_partitioning.param_with_axes(
-            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
+        scale = self.param(
+            "scale",
+            nn.with_logical_partitioning(self.scale_init, ("embed",)),
+            (features,),
+            self.dtype,
         )

         x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":

@@ -862,8 +866,11 @@ class LayerNorm(nn.Module):
             var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
             y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
-            bias = nn_partitioning.param_with_axes(
-                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
+            bias = self.param(
+                "ln_bias",
+                nn.with_logical_partitioning(self.bias_init, ("embed",)),
+                (features,),
+                self.dtype,
             )
             bias = jnp.asarray(bias, input_dtype)

@@ -972,12 +979,11 @@ class RelativePositionBiases(nn.Module):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
             (self.num_heads, self.num_buckets),
             jnp.float32,
-            axes=("heads", "relpos_buckets"),
         )
         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

@@ -1555,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
     """
     src_values = {}
     for key, value in jax.tree_util.tree_leaves_with_path(src):
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         src_values[normalized_key] = value

     flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
     synced_dst_values = []
     for key, value in flatten_dst:
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         if normalized_key in transformations:
             corresponding_src_key = transformations[normalized_key]
         else:
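The hasattr(x, "key") filter works because JAX tree paths mix entry types: dict entries are DictKey objects carrying .key, while attribute entries at the end of a path are GetAttrKey objects carrying .name instead. A small demonstration:

    import jax
    import jax.numpy as jnp

    tree = {"block": {"kernel": jnp.zeros((2, 2))}}
    for path, _leaf in jax.tree_util.tree_leaves_with_path(tree):
        # Keep only dict-style path entries when building "a/b" lookup keys.
        normalized = "/".join(p.key for p in path if hasattr(p, "key"))
        assert normalized == "block/kernel"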
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py

@@ -16,6 +16,7 @@ import torch.distributed as dist
 from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
+    Float8BlockScaling,
     Format,
     Recipe,
 )

@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor

 def _get_raw_data(quantized_tensor):

@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor):
         assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
         assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
         return quantized_tensor._data
+    elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
+        assert hasattr(
+            quantized_tensor, "_rowwise_data"
+        ), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
+        assert (
+            quantized_tensor._rowwise_data.dtype == torch.uint8
+        ), "Float8BlockwiseQTensor _rowwise_data must be uint8"
+        return quantized_tensor._rowwise_data
     else:
         raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")

@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256, **linear_kwargs),
-            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 16, **linear_kwargs),
+            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256, **linear_kwargs),
-        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 16, **linear_kwargs),
+        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )

@@ -539,12 +549,13 @@ def _test_zero_1(dp_group):
 def quantization_recipe(quantization) -> Recipe:
     """Quantization recipe setup"""
+    fp8_format = Format.HYBRID
     if quantization == "fp8":
-        return DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")
+        return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
     elif quantization == "fp8_cs":
-        return Float8CurrentScaling()
+        return Float8CurrentScaling(fp8_format=fp8_format)
+    elif quantization == "fp8_block":
+        return Float8BlockScaling(fp8_format=fp8_format)
     else:
         raise ValueError(f"Unsupported quantization: {quantization}")

@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256, **linear_kwargs),
-            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 16, **linear_kwargs),
+            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256, **linear_kwargs),
-        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 16, **linear_kwargs),
+        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )

@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
     optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
     optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

-    for _ in range(100):
+    for i in range(100):
         for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
             w_fp8.main_grad.zero_()
             w.main_grad.zero_()

@@ -654,7 +665,9 @@ def main(argv=None, namespace=None):
     dist.init_process_group(**dist_init_kwargs)

     parser = argparse.ArgumentParser()
-    parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
+    parser.add_argument(
+        "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
+    )
     args = parser.parse_args(argv, namespace)

     dp_group = dist.new_group(backend="nccl")
tests/pytorch/distributed/run_gemm_with_overlap.py

@@ -21,7 +21,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch.module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_cublas_workspace_size_bytes,
+)

 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)

@@ -57,7 +61,11 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--quantization",
+        type=str.lower,
+        default="none",
+        choices=["none", "fp8", "mxfp8"],
+        help="Quantization recipe",
     )
     parser.add_argument(
         "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."

@@ -155,9 +163,9 @@ def _parse_args(argv=None, namespace=None):
         if opts.atomic:
             warnings.warn("Atomic GEMM is not supported with bulk overlap.")
             opts.atomic = False
-        if opts.fp8:
+        if opts.quantization != "none":
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
-            opts.fp8 = False
+            opts.quantization = "none"
     elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)

@@ -165,8 +173,11 @@ def _parse_args(argv=None, namespace=None):
     if opts.atomic:
         if not te.fp8.check_fp8_support():
             assert not opts.fp8, "Atomic GEMM is only supported in FP8."
-        opts.fp8 = True
+        assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
+        opts.quantization = "fp8"
+
+    if opts.fp8_output:
+        assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."

     return opts

@@ -303,7 +314,11 @@ def _main(opts):
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)

     buffer_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
+    if (
+        opts.quantization != "none"
+        and not opts.bulk_overlap
+        and opts.comm_type == tex.CommOverlapType.AG
+    ):
         buffer_dtype = torch.uint8
     ub_obj = (
         tex.CommOverlapP2P(

@@ -450,6 +465,8 @@ def _main(opts):
         inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
         ref2_g = torch.matmul(inp2_g, ker2_g)

+    # Initialize quantizers
+    with_quantized_compute = opts.quantization != "none"
     inp_quantizer = None
     ker_quantizer = None
     out_quantizer = None

@@ -457,7 +474,7 @@ def _main(opts):
     inp2_quantizer = None
     ker2_quantizer = None
     out2_quantizer = None
-    if opts.fp8:
+    if opts.quantization == "fp8":
         # Structure to maintain amax and scale/scale_inv information for the kernel and input
         num_gemms = 6 if ub_obj2 is not None else 3
         fp8_dtype = tex.DType.kFloat8E4M3

@@ -502,11 +519,23 @@ def _main(opts):
             out2_quantizer = Float8Quantizer(fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype)
+    elif opts.quantization == "mxfp8":
+        fp8_dtype = tex.DType.kFloat8E4M3
+        inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        ker_quantizer = MXFP8Quantizer(fp8_dtype)
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
+            bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        elif ub_obj2 is not None:
+            inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+            ker2_quantizer = MXFP8Quantizer(fp8_dtype)

-        # Cast input to Float8Tensor
+    # Quantize tensors
+    if with_quantized_compute:
+        # Quantize input tensor
         inp_fp8 = inp_quantizer(inp)

-        # Cast kernel to Float8Tensor
+        # Quantize kernel tensor
         kernel_t_fp8 = ker_quantizer(kernel_t)

         if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)

@@ -543,31 +572,40 @@ def _main(opts):
     )

     # Set up comm/compute buffers
+    ag_out = None
     rs_out = None
     rs_out2 = None
     if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                bulk_inp,
+                bulk_inp_quantizer,
+                tp_group,
+            )
             gemm_inp = inp
         else:
-            ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
-            gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                inp_fp8 if with_quantized_compute else inp,
+                inp_quantizer,
+                tp_group,
+            )
+            gemm_inp = ag_out
         if ub_obj2 is not None:
             if opts.fp8 and opts.fp8_output:
                 ub_obj2.set_buffer_params(out_quantizer)
             rs_out2 = torch.empty((outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda")
     else:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False)
-            if opts.fp8:
-                ub_obj.set_buffer_params(bulk_inp_quantizer)
+            if opts.quantization == "none":
+                ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
+            if opts.quantization == "fp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
+            elif opts.quantization == "mxfp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
         elif opts.fp8 and opts.fp8_output:
             ub_obj.set_buffer_params(out_quantizer)
-        gemm_inp = inp_fp8 if opts.fp8 else inp
+        gemm_inp = inp_fp8 if with_quantized_compute else inp
         rs_out = torch.empty((outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda")

@@ -626,7 +664,7 @@ def _main(opts):
     if opts.use_cuda_graphs:
         # Trace the CUDA graph first
         g = torch.cuda.CUDAGraph()
-        if opts.fp8:
+        if with_quantized_compute:
            if ub_obj is None:
                with torch.cuda.graph(g):
                    all_outputs = _fp8_gemm()

@@ -646,7 +684,7 @@ def _main(opts):
     else:
         for i in range(total_iters):
-            if opts.fp8:
+            if with_quantized_compute:
                 start_events[i].record()
                 all_outputs = _fp8_gemm()
                 end_events[i].record()

@@ -691,10 +729,22 @@ def _main(opts):
         output_info = ""
         if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
-            test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
+            test_out = ag_out
+            if bulk_inp_quantizer is None:
+                test_out = ub_obj.get_buffer(False)
+            else:
+                test_out = Float8Tensor(
+                    shape=test_out.shape,
+                    dtype=torch.bfloat16,
+                    data=ub_obj.get_buffer(False),
+                    fp8_scale=bulk_inp_quantizer.scale,
+                    fp8_dtype=bulk_inp_quantizer.dtype,
+                    quantizer=bulk_inp_quantizer,
+                )
         else:
             # Bulk overlap RS output needs to be gathered
-            out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
+            out_local = ub_obj.get_buffer(True)
             output_info += f"rs_output: {list(out_local.shape)} | "
             test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]

@@ -765,8 +815,8 @@ def _main(opts):
         m = torch.argmax(diff)
         abs_err = diff[m].item()
         rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
-        rtol = 0.125 if opts.fp8 else 0.02
-        atol = 0.0625 if opts.fp8 else 0.01
+        rtol = 0.02 if opts.quantization == "none" else 0.125
+        atol = 0.01 if opts.quantization == "none" else 0.0625
         if rel_err > rtol and abs_err > atol:
             numerics_failed = True
             numerics_info = (
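The acceptance rule above only fails a comparison when the worst element exceeds both the relative and the absolute budget, with looser budgets for quantized runs. The same check as a standalone function:

    import torch

    def numerics_ok(test_out, ref_out, quantized):
        diff = torch.abs(test_out.flatten() - ref_out.flatten())
        m = torch.argmax(diff)
        abs_err = diff[m].item()
        rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
        rtol = 0.125 if quantized else 0.02
        atol = 0.0625 if quantized else 0.01
        # Fail only when both the relative and the absolute budgets are blown.
        return not (rel_err > rtol and abs_err > atol)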
tests/pytorch/distributed/run_layer_with_overlap.py

@@ -17,7 +17,12 @@
 import torch
 import torch.distributed as dist

 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+    Format,
+    MXFP8BlockScaling,
+)

 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)

@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
         "--quantization",
         type=str.lower,
         default="none",
-        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
+        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
         help="Quantization recipe",
     )
     parser.add_argument(

@@ -414,6 +419,8 @@ def _train(opts):
         )
     elif opts.quantization == "fp8_current_scaling":
         fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
+    elif opts.quantization == "mxfp8":
+        fp8_recipe = MXFP8BlockScaling()

     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
tests/pytorch/distributed/run_numerics.py

@@ -174,7 +174,7 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.3e-6, "atol": 1e-5}
+        return {"rtol": 1.3e-6, "atol": 4e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")
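For reference, these tolerances plug directly into torch.testing.assert_close; a sketch reusing the values above:

    import torch

    def _get_tolerances(dtype):
        # Values as in run_numerics.py after this commit.
        if dtype == torch.bfloat16:
            return {"rtol": 1.6e-2, "atol": 1e-5}
        if dtype == torch.float32:
            return {"rtol": 1.3e-6, "atol": 4e-5}
        raise ValueError(f"Unsupported dtype ({dtype})")

    a = torch.ones(8)
    torch.testing.assert_close(a, a + 3e-5, **_get_tolerances(torch.float32))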
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py

@@ -15,6 +15,9 @@ if torch.cuda.device_count() < 2:
     pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)

 TEST_ROOT = Path(__file__).parent.resolve()
 NUM_PROCS: int = min(2, torch.cuda.device_count())

@@ -28,8 +31,10 @@ def _run_test(quantization):
     assert result.returncode == 0

-@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
+@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
 def test_cast_master_weights_to_fp8(quantization):
-    if not fp8_available:
+    if quantization in ("fp8", "fp8_cs") and not fp8_available:
         pytest.skip(reason_for_no_fp8)
+    if quantization == "fp8_block" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     _run_test(quantization)
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
f8c2af4c
...
...
@@ -21,6 +21,7 @@ if torch.cuda.device_count() < 2:
pytest
.
skip
(
"Comm+GEMM overlap requires at least 2 GPUs."
)
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
RNG_SEED
:
int
=
42
SEQ_LENGTH
:
int
=
1024
...
...
@@ -56,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
torch
.
_dynamo
.
reset
()
def
_run_gemm_with_overlap
(
comm_type
,
bulk
,
p2p
,
atomic
,
fp8
):
def
_run_gemm_with_overlap
(
comm_type
,
bulk
,
p2p
,
atomic
,
quantization
):
test_path
=
TEST_ROOT
/
"run_gemm_with_overlap.py"
test_cmd
=
LAUNCH_CMD
+
[
str
(
test_path
),
...
...
@@ -72,10 +73,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
if
bulk
:
test_cmd
.
append
(
"--bulk-overlap"
)
else
:
if
fp8
:
if
not
fp8_available
:
if
quantization
==
"fp8"
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
test_cmd
.
append
(
"--fp8"
)
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
test_cmd
.
append
(
f
"--quantization=
{
quantization
}
"
)
if
p2p
:
test_cmd
.
append
(
"--p2p"
)
if
atomic
:
...
...
@@ -114,8 +116,10 @@ def _run_layer_with_overlap(
test_cmd
.
append
(
"--overlap-rs-dgrad"
)
if
fp8
:
if
not
fp8_available
:
if
quantization
in
(
"fp8_delayed_scaling"
,
"fp8_current_scaling"
)
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
test_cmd
.
append
(
"--fp8"
)
test_cmd
.
append
(
f
"--quantization=
{
quantization
}
"
)
...
...
@@ -137,51 +141,34 @@ def _run_layer_with_overlap(
         raise AssertionError(result.stderr.decode())
 
 
-@pytest.mark.parametrize(
-    "fp8",
-    (False, True),
-    ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
-)
-def test_split_all_gather_overlaps(fp8):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+def test_split_all_gather_overlaps(quantization):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, fp8)
+    _run_gemm_with_overlap("AG", False, True, False, quantization)
 
 
-@pytest.mark.parametrize(
-    "fp8,p2p",
-    [(False, False), (False, True), (True, False), (True, True)],
-    ids=[
-        " BF16 - PIPELINE ",
-        " BF16 - RING-EXCHANGE ",
-        " FP8 - PIPELINE ",
-        " FP8 - RING-EXCHANGE ",
-    ],
-)
-def test_split_reduce_scatter_overlaps(fp8, p2p):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+@pytest.mark.parametrize("p2p", (False, True))
+def test_split_reduce_scatter_overlaps(quantization, p2p):
     """
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, fp8)
+    _run_gemm_with_overlap("RS", False, p2p, False, quantization)
 
 
 @pytest.mark.parametrize(
-    "comm_type, fp8, connections",
+    "comm_type, quantization, connections",
     [
-        ("AG", False, 1),
-        ("RS", False, 1),
-        ("RS", True, 1),
-        ("AG", False, 8),
-        ("RS", False, 8),
-        ("RS", True, 8),
+        ("AG", "none", 1),
+        ("RS", "none", 1),
+        ("RS", "fp8", 1),
+        ("AG", "none", 8),
+        ("RS", "none", 8),
+        ("RS", "fp8", 8),
     ],
     ids=[
         "ALL-GATHER - BF16 - 1 connections",
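Note that the rewrite swaps a hand-enumerated tuple list for stacked `parametrize` decorators, which compose as a cross product (3 quantization values x 2 p2p values = 6 cases). A self-contained illustration of that pytest behavior:

import pytest

@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
@pytest.mark.parametrize("p2p", (False, True))
def test_cross_product(quantization, p2p):
    # pytest generates one test per (p2p, quantization) combination: 6 in total.
    assert quantization in ("none", "fp8", "mxfp8")
    assert p2p in (False, True)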
@@ -192,7 +179,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
         "REDUCE-SCATTER - FP8 - 8 connections",
     ],
 )
-def test_bulk_overlaps(comm_type, fp8, connections):
+def test_bulk_overlaps(comm_type, quantization, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
@@ -203,10 +190,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
             " 9.0 (HOPPER ARCH)."
         )
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
 
 
 @pytest.mark.parametrize(
@@ -258,15 +245,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[
-        " FP8 ",
-    ],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
@@ -286,15 +265,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         )
     ),
     ids=[
-        f" {te.Linear.__name__} - ROW-PARALLEL ",
-        f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
-        f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
+        f"{te.Linear.__name__}-row_tensor_parallel",
+        f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
+        f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
     ]
     + [
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
@@ -302,12 +281,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
     ],
 )
 def test_layers_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
+    layer_type,
+    linear_parallel_mode,
+    overlap_rs_dgrad,
+    quantization,
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
-    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
 
 
 @pytest.mark.parametrize(
@
pytest
.
mark
.
parametrize
(
"quantization"
,
[
"fp8_delayed_scaling"
,
"fp8_current_scaling"
],
ids
=
[
" DELAYED SCALING "
,
" CURRENT SCALING "
],
)
@
pytest
.
mark
.
parametrize
(
"fp8"
,
(
True
,),
ids
=
[
" FP8 "
,
],
[
"fp8_delayed_scaling"
,
"fp8_current_scaling"
,
"mxfp8"
],
)
@
pytest
.
mark
.
parametrize
(
"num_layers"
,
(
2
,),
ids
=
[
" 2 layers "
,
],
)
@
pytest
.
mark
.
parametrize
(
"layer_type,linear_parallel_mode,overlap_rs_dgrad"
,
...
...
@@ -381,7 +352,7 @@ def test_multi_layer_with_overlap_bf16(
         )
     ),
     ids=[
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [te.TransformerLayer.__name__ for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"],
@@ -389,11 +360,11 @@ def test_multi_layer_with_overlap_bf16(
     ],
 )
 def test_multi_layer_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+    layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
     _run_layer_with_overlap(
-        layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+        layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
     )
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py View file @ f8c2af4c
@@ -19,7 +19,6 @@ import torch
 import transformer_engine
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
 )
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 
 # Import utility functions
@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+quantization_list: list[Optional[str]] = [None]
+if fp8_available:
+    quantization_list.append("fp8")
+if mxfp8_available:
+    quantization_list.append("mxfp8")
 
 # Check if there are multiple GPUs
 if torch.cuda.device_count() < 2:
@@ -51,7 +59,7 @@ class ModelConfig:
     num_heads: int
     head_dim: int
     dtype: torch.dtype
-    fp8: bool
+    quantization: Optional[str]
 
     @property
     def hidden_size(self):
@@ -129,11 +137,15 @@ def make_reference_and_test_tensors(
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)
 
     # Make copy of tensor
-    if test_is_fp8:
-        test = Float8Tensor.to_float8(ref)
-    else:
-        test = ref.to(device=test_device, dtype=test_dtype)
-        if test.data_ptr() == ref.data_ptr():
+    test = ref.to(device=test_device, dtype=test_dtype)
+    if test_is_fp8:
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif test.data_ptr() == ref.data_ptr():
         test = test.clone()
 
     # Make sure reference and test tensors represent exact same values
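A minimal sketch of the quantizer-based API used above: build a `Float8Quantizer` with unit scale, quantize a tensor, and dequantize it back. It assumes a CUDA device with FP8 support; the tensor names are ours, while the constructor arguments and `dequantize()` call mirror the diff.

import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda"),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = torch.randn(128, 128, device="cuda")
x_fp8 = quantizer(x)              # Float8Tensor holding FP8 data plus its scale
x_roundtrip = x_fp8.dequantize()  # back to higher precision, with FP8 rounding error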
@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
     return ref, test
 
 
+def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
+    """Make recipe for quantization scheme"""
+    if name is None:
+        return None
+    if name == "fp8":
+        return transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "mxfp8":
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
 def _test_linear(
     *,
     model_config: ModelConfig,
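Usage of the new helper is straightforward; a two-line sketch with names from this file:

recipe = make_recipe("mxfp8")  # -> transformer_engine.common.recipe.MXFP8BlockScaling
assert make_recipe(None) is None  # no quantization requested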
@@ -155,7 +182,8 @@ def _test_linear(
     weight_requires_grad: bool = True,
 ) -> None:
     dtype = model_config.dtype
-    fp8_compute = model_config.fp8
+    quantization = model_config.quantization
+    quantized_compute = quantization is not None
 
     # Distributed process group
     process_group = world_group()
@@ -175,14 +203,19 @@ def _test_linear(
         in_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(x_test, QuantizedTensor):
+        with torch.no_grad():
+            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(w_test, QuantizedTensor):
+        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":
@@ -198,9 +231,11 @@ def _test_linear(
         out_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
+    if isinstance(dy_test, QuantizedTensor):
+        dy_test = dy_test.dequantize()
 
     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
@@ -265,21 +300,15 @@ def _test_linear(
         x_test.requires_grad_()
 
     # Implementation with fusible operation
-    with te.fp8_model_init(enabled=fp8_compute):
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
         ops = []
         linear_op = None
         bias_op = None
         if tensor_parallel_mode == "column":
             userbuffers_options = {}
             if not weight_requires_grad:
-                if fp8_compute:
-                    userbuffers_options["comm_name"] = "fc1"
-                else:
-                    # There is a correctness bug with overlapping
-                    # dgrad reduce-scatter with dgrad GEMM. Fall back
-                    # to overlapping dgrad reduce-scatter with wgrad
-                    # GEMM, even though wgrad isn't needed.
-                    userbuffers_options["comm_name"] = "qkv"
+                # There is a correctness bug with overlapping
+                # dgrad reduce-scatter with dgrad GEMM. Fall back
+                # to overlapping dgrad reduce-scatter with wgrad
+                # GEMM, even though wgrad isn't needed.
+                userbuffers_options["comm_name"] = "qkv"
             else:
                 userbuffers_options["comm_name"] = "qkv"
             linear_op = te_ops.BasicLinear(
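The pairing this hunk introduces, `fp8_model_init` for quantized weights plus a matching `fp8_autocast` for quantized compute, can be exercised standalone. A minimal sketch, assuming a CUDA device with FP8 support; `te_ops.Sequential` and `BasicLinear` are the fusible-ops API used throughout this file, and `make_recipe` is the helper defined above:

import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

recipe = make_recipe("fp8")
with te.fp8_model_init(enabled=recipe is not None, recipe=recipe):
    # Weights are created directly in the quantized representation
    model = te_ops.Sequential(te_ops.BasicLinear(1024, 1024, device="cuda"))
with te.fp8_autocast(enabled=recipe is not None, fp8_recipe=recipe):
    # Forward pass runs FP8 GEMMs under the same recipe
    y = model(torch.randn(32, 1024, device="cuda"))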
@@ -322,7 +351,7 @@ def _test_linear(
             bias_op.bias.copy_(b_test)
         del w_test
         del b_test
-    with te.fp8_autocast(enabled=fp8_compute):
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
         y_test = model(x_test)
     y_test.backward(dy_test)
@@ -338,7 +367,7 @@ def _test_linear(
     tols = dtype_tols(dtype)
     if dtype == torch.float32:
         tols = dtype_tols(torch.float16)  # TF32 GEMM
-    if fp8_compute:
+    if quantized_compute:
         tols = dtype_tols(
             model[0].weight._fp8_dtype
             if is_float8_tensor(model[0].weight)
@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
     for test_config in itertools.product(
         (False, True),  # bias
         ("column", "row"),  # tensor_parallel_mode
-        (False, True),  # weight_requires_grad
+        (True, False),  # weight_requires_grad
     ):
         if rank == 0:
             print(f"Running _test_linear with {test_config=}")
@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
 @pytest.mark.parametrize("world_size", _world_sizes)
-@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("quantization", quantization_list)
 def test_fuser_ops_with_userbuffers(
     *,
     world_size: int,
     dtype: torch.dtype = torch.bfloat16,
-    fp8: bool,
+    quantization: Optional[str],
 ) -> None:
     """Launch parallel job and run tests"""
 
-    # Skip invalid configurations
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
     # Parallel job launcher
     command = []
     if tex.ubuf_built_with_mpi():
@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
             str(dtype),
         )
     )
-    if fp8:
-        command.append("--fp8")
+    if quantization is not None:
+        command.extend(("--quantization", quantization))
 
     # Environment
     env = dict(os.environ)
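Note the switch from `append` to `extend`: a boolean `store_true` flag is a single argv entry, while a flag that takes a value needs two. A one-line illustration:

command = ["python", "run.py"]
command.append("--fp8")                      # store_true flag: one argv entry
command.extend(("--quantization", "mxfp8"))  # flag plus value: two argv entries
# command == ["python", "run.py", "--fp8", "--quantization", "mxfp8"]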
@@ -445,12 +470,12 @@ def main() -> None:
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
     parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
-    parser.add_argument("--sequence-length", type=int, default=32)
+    parser.add_argument("--sequence-length", type=int, default=256)
     parser.add_argument("--batch-size", type=int, default=16)
     parser.add_argument("--num-heads", type=int, default=16)
-    parser.add_argument("--head-dim", type=int, default=32)
+    parser.add_argument("--head-dim", type=int, default=256)
     parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument("--quantization", type=str, default=None)
     args = parser.parse_args()
 
     # Run parallel tests if needed
@@ -463,14 +488,17 @@ def main() -> None:
         num_heads=args.num_heads,
         head_dim=args.head_dim,
         dtype=str_to_dtype(args.dtype),
-        fp8=args.fp8,
+        quantization=args.quantization,
     )
 
     # Initialize Userbuffers
     group = world_group()  # Initialize NCCL
     bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
     userbuffer_configs = {
-        "fc1_dgrad": {"method": "pipeline"},  # Overlap dgrad RS with dgrad GEMM
+        "fc1_dgrad": {
+            "method": "ring_exchange",
+            "fp8_buf": False,
+        },  # Overlap dgrad RS with dgrad GEMM
     }
     te.module.base.initialize_ub(
         [
@@ -478,7 +506,7 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.fp8,
+        use_fp8=model_config.quantization is not None,
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
tests/pytorch/fused_attn/run_fused_attn_with_cp.py View file @ f8c2af4c
@@ -2,12 +2,16 @@
 #
 # See LICENSE for license information.
 
-import os, sys, logging
+import os
+import sys
+import logging
 from contextlib import nullcontext
 import torch
 import torch.distributed as dist
 from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_cu_seqlens_on_cp_rank,
+)
 import transformer_engine_torch as tex
 from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
tests/pytorch/fused_attn/test_fused_attn.py View file @ f8c2af4c
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import logging
import math
import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union, Optional
from contextlib import contextmanager
@@ -16,26 +13,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
-from transformer_engine.pytorch.attention import (
+from transformer_engine.pytorch.attention.dot_product_attention import (
     DotProductAttention,
-    MultiheadAttention,
     _attention_backends,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import (
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,
 )
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
 from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     AttnBiasType,
     AttnMaskType,
     FusedAttnBackend,
     QKVLayout,
     fused_attn_bwd,
     fused_attn_fwd,
 )
@@ -50,9 +43,7 @@ from transformer_engine.pytorch.utils import (
 )
 from transformer_engine.pytorch.utils import get_cudnn_version
 import transformer_engine_torch as tex
-from transformer_engine_torch import NVTE_Fused_Attn_Backend
 from transformer_engine.pytorch.tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,