OpenDAS / TransformerEngine / Commits

Commit f8c2af4c
Authored May 21, 2025 by yuguo
Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
Parents: e92773a3, 1d903f5e
Changes: 211

Showing 20 changed files with 548 additions and 362 deletions (+548, -362)
Changed files:

tests/cpp/operator/test_normalization.cu (+10, -10)
tests/cpp/operator/test_normalization_mxfp8.cu (+7, -7)
tests/cpp/operator/test_qdq.cu (+4, -4)
tests/cpp/operator/test_transpose.cu (+2, -2)
tests/cpp/test_common.cu (+6, -11)
tests/cpp/test_common.h (+2, -0)
tests/jax/test_custom_call_compute.py (+138, -35)
tests/jax/test_distributed_layernorm.py (+3, -0)
tests/jax/test_distributed_layernorm_mlp.py (+1, -27)
tests/jax/test_helper.py (+78, -46)
tests/jax/utils.py (+49, -41)
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (+27, -14)
tests/pytorch/distributed/run_gemm_with_overlap.py (+79, -29)
tests/pytorch/distributed/run_layer_with_overlap.py (+9, -2)
tests/pytorch/distributed/run_numerics.py (+1, -1)
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py (+7, -2)
tests/pytorch/distributed/test_comm_gemm_overlap.py (+44, -73)
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py (+65, -37)
tests/pytorch/fused_attn/run_fused_attn_with_cp.py (+6, -2)
tests/pytorch/fused_attn/test_fused_attn.py (+10, -19)
tests/cpp/operator/test_normalization.cu

@@ -49,16 +49,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }

-  Tensor input("input", {N, H}, itype);
-  Tensor z("z", {N, H}, otype);
-  Tensor gamma("gamma", {H}, wtype);
-  Tensor beta("beta", {H}, wtype);
-  Tensor mu("mu", {N}, DType::kFloat32);
-  Tensor rsigma("rsigma", {N}, DType::kFloat32);
-  Tensor dz("dz", {N, H}, wtype);
-  Tensor dx("dx", {N, H}, itype);
-  Tensor dgamma("dgamma", {H}, wtype);
-  Tensor dbeta("dbeta", {H}, wtype);
+  Tensor input("input", std::vector<size_t>{N, H}, itype);
+  Tensor z("z", std::vector<size_t>{N, H}, otype);
+  Tensor gamma("gamma", std::vector<size_t>{H}, wtype);
+  Tensor beta("beta", std::vector<size_t>{H}, wtype);
+  Tensor mu("mu", std::vector<size_t>{N}, DType::kFloat32);
+  Tensor rsigma("rsigma", std::vector<size_t>{N}, DType::kFloat32);
+  Tensor dz("dz", std::vector<size_t>{N, H}, wtype);
+  Tensor dx("dx", std::vector<size_t>{N, H}, itype);
+  Tensor dgamma("dgamma", std::vector<size_t>{H}, wtype);
+  Tensor dbeta("dbeta", std::vector<size_t>{H}, wtype);
   Tensor workspace_fwd, workspace_bwd;

   fillUniform(&input);
tests/cpp/operator/test_normalization_mxfp8.cu

@@ -116,12 +116,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   DType wtype = TypeInfo<WeightType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;

-  Tensor input("input", {N, H}, itype);
-  Tensor z("z", {N, H}, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
-  Tensor gamma("gamma", {H}, wtype);
-  Tensor beta("beta", {H}, wtype);
-  Tensor mu("mu", {N}, DType::kFloat32);
-  Tensor rsigma("rsigma", {N}, DType::kFloat32);
+  Tensor input("input", std::vector<size_t>{N, H}, itype);
+  Tensor z("z", std::vector<size_t>{N, H}, otype, true, is_training, NVTE_MXFP8_1D_SCALING);
+  Tensor gamma("gamma", std::vector<size_t>{H}, wtype);
+  Tensor beta("beta", std::vector<size_t>{H}, wtype);
+  Tensor mu("mu", std::vector<size_t>{N}, DType::kFloat32);
+  Tensor rsigma("rsigma", std::vector<size_t>{N}, DType::kFloat32);
   Tensor workspace;

@@ -164,7 +164,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     nvte_enable_zero_centered_gamma_in_weight_dtype(false);
   }

-  Tensor dequantized_output("dequantized_output", {N, H}, DType::kFloat32, true, true);
+  Tensor dequantized_output("dequantized_output", std::vector<size_t>{N, H}, DType::kFloat32, true, true);
   dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);
tests/cpp/operator/test_qdq.cu

@@ -58,8 +58,8 @@ void performTestQ(const size_t N) {
   DType itype = TypeInfo<InputType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;

-  Tensor input("input", {N}, itype);
-  Tensor output("output", {N}, otype);
+  Tensor input("input", std::vector<size_t>{N}, itype);
+  Tensor output("output", std::vector<size_t>{N}, otype);

   std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);

@@ -89,8 +89,8 @@ void performTestDQ(const size_t N) {
   DType itype = TypeInfo<InputType>::dtype;
   DType otype = TypeInfo<OutputType>::dtype;

-  Tensor input("input", {N}, itype);
-  Tensor output("output", {N}, otype);
+  Tensor input("input", std::vector<size_t>{N}, itype);
+  Tensor output("output", std::vector<size_t>{N}, otype);

   std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N);
tests/cpp/operator/test_transpose.cu

@@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) {
   DType dtype = TypeInfo<Type>::dtype;

-  Tensor input("input", {N, H}, dtype);
-  Tensor output("output", {H, N}, dtype);
+  Tensor input("input", std::vector<size_t>{N, H}, dtype);
+  Tensor output("output", std::vector<size_t>{H, N}, dtype);

   std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
tests/cpp/test_common.cu

@@ -783,8 +783,6 @@ void fillUniform(Tensor *t) {
 template <typename InputEncoding, InputsFillCase Case>
 void fillCase_special(Tensor *t) {
   const size_t size = product(t->rowwise_shape());
-  const size_t rows = t->rowwise_shape().data[0];
-  const size_t cols = t->rowwise_shape().data[1];

   if constexpr (Case == InputsFillCase::zeros) {
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {

@@ -804,16 +802,13 @@ void fillCase_special(Tensor *t) {
     std::uniform_real_distribution<> dis_sign(-1.0, 1.0);
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, {
       InputType *data = t->rowwise_cpu_dptr<InputType>();
-      for (size_t i = 0; i < rows; ++i) {
-        for (size_t j = 0; j < cols; ++j) {
-          const bool is_negative = (dis_sign(t->gen()) < 0.0);
-          const size_t idx = i * cols + j;
-          double val = dis(t->gen());
-          if (is_negative) {
-            val = -val;
-          }
-          data[idx] = static_cast<InputType>(val);
-        }
-      }
+      for (size_t idx = 0; idx < size; ++idx) {
+        const bool is_negative = (dis_sign(t->gen()) < 0.0);
+        double val = dis(t->gen());
+        if (is_negative) {
+          val = -val;
+        }
+        data[idx] = static_cast<InputType>(val);
+      }
     });
   }
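Note on the fillCase_special hunk above: the removed nested (row, column) loop and the new flat loop over size visit the same elements in the same order, since the nested loop computed idx = i * cols + j. A quick standalone check of that index arithmetic (plain Python, not part of the commit):

rows, cols = 3, 4
nested_order = [i * cols + j for i in range(rows) for j in range(cols)]
flat_order = list(range(rows * cols))
assert nested_order == flat_order  # same indices, same visit order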
tests/cpp/test_common.h

@@ -52,6 +52,7 @@ struct BytesToType<8> {
 };

 using byte = uint8_t;
+using int16 = int16_t;
 using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;

@@ -70,6 +71,7 @@ using fp8e8m0 = uint8_t;
 template <typename T>
 struct TypeInfo {
   using types = std::tuple<byte,
+                           int16,
                            int32,
                            int64,
                            fp32,
tests/jax/test_custom_call_compute.py

@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
+import numpy as np
 import pytest
 from jax import jit, value_and_grad
 from functools import reduce

@@ -18,11 +19,16 @@ from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp
 from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
-from transformer_engine.jax.cpp_extensions.normalization import _jax_layernorm, _jax_rmsnorm
+from transformer_engine.jax.cpp_extensions.normalization import (
+    _jax_layernorm,
+    _jax_rmsnorm,
+    is_norm_zero_centered_gamma_in_weight_dtype,
+)
 from transformer_engine.jax.cpp_extensions.quantization import (
     _jax_quantize,
     _jax_quantize_dbias,
 )
+from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
     DelayedScaleQuantizer,

@@ -33,7 +39,7 @@ from transformer_engine.jax.quantize import (
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
-from transformer_engine.jax.dense import dense, grouped_dense
+from transformer_engine.jax.dense import dense
 from transformer_engine.jax.layernorm_dense import layernorm_dense
 from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x

@@ -54,6 +60,7 @@ supported_scaling_modes = []
 """ Find supported scaling modes"""
 if is_fp8_supported:
     supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
+    supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING)
 if is_mxfp8_supported:
     supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)

@@ -71,8 +78,19 @@ def is_shape_supported_by_mxfp8(input_shape):
 def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
     if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
+        assert a.scaling_mode == b.scaling_mode
+        assert a.scale_inv.dtype == b.scale_inv.dtype
+        if a.scaling_mode.is_tensor_scaling():
+            # Assert in dq_dtype as some unfused codepaths have an intermediate cast
+            # to an input dtype which reduces precision compared to everything in fp32
+            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
+        elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            # Compare MXFP8 scales as uint8
+            assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
+        else:
+            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")
         assert_allclose(a.data, b.data)
-        assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8))
     elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
         assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor)
         assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor)

@@ -159,7 +177,12 @@ class TestActivation:
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_grad_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, scaling_mode
+    ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)
         x = jnp.repeat(x, len(activation_type), axis=-2)

@@ -170,7 +193,7 @@ class TestActivation:
         )
         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=QuantizeLayout.ROWWISE,
         )

@@ -188,8 +211,11 @@ class TestActivation:
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_act_forward_with_delayed_scaling_fp8(
-        self, random_inputs, activation_type, output_type, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_act_forward_with_tensor_scaling_fp8(
+        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
     ):
         x = random_inputs
         x = jnp.expand_dims(x, axis=-2)

@@ -198,7 +224,7 @@ class TestActivation:
         te_quantizer, jax_quantizer = QuantizerFactory.create(
             n_quantizers=2,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_dtype=output_type,
             q_layout=q_layout,
         )

@@ -335,8 +361,20 @@ class TestNorm:
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_norm_grad_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_grad_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
     ):
         """
         Test transformer_engine.jax.layernorm.layernorm

@@ -345,9 +383,7 @@ class TestNorm:
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
-            q_dtype=out_dtype,
-            q_layout=q_layout,
+            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
         )
         self._test_norm_grad(
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer

@@ -395,7 +431,41 @@ class TestNorm:
             ref_mu = None

-        assert_bitwise_scaled_tensors(output, ref_out)
+        precise_comparison = True
+        if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING:
+            # Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead
+            # do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision
+            precise_comparison = False
+        elif is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode):
+            # Larger tolerances as our JAX implementation _jax_*norm uses the compute dtype float32
+            # for zero-centered gamma always
+            precise_comparison = False
+        elif scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32:
+            # Current implementation of Current Tensor Scaling performs unfused layernorm and quantization
+            # and writes intermediate results into the input dtype, which will slightly reduce precision
+            # if the input dtype is not float32
+            precise_comparison = False
+
+        if precise_comparison:
+            assert_bitwise_scaled_tensors(output, ref_out)
+        else:
+            if isinstance(ref_out, ScaledTensor1x):
+                assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype)
+            elif isinstance(ref_out, ScaledTensor2x):
+                assert_allclose(
+                    output.rowwise_tensor.dequantize(),
+                    ref_out.rowwise_tensor.dequantize(),
+                    dtype=out_dtype,
+                )
+                assert_allclose(
+                    output.colwise_tensor.dequantize(),
+                    ref_out.colwise_tensor.dequantize(),
+                    dtype=out_dtype,
+                )
+            else:
+                pytest.fail("Unsupported output type")

         assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
         if norm_type == "layernorm":
             assert_allclose(mu, ref_mu, dtype=inp_dtype)

@@ -406,8 +476,20 @@ class TestNorm:
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_norm_forward_with_delayed_scaling_fp8(
-        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_norm_forward_with_tensor_scaling_fp8(
+        self,
+        n,
+        hidden,
+        norm_type,
+        zero_centered_gamma,
+        epsilon,
+        inp_dtype,
+        out_dtype,
+        q_layout,
+        scaling_mode,
     ):
         if norm_type == "rmsnorm" and zero_centered_gamma is True:
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

@@ -420,7 +502,7 @@ class TestNorm:
             epsilon=epsilon,
             inp_dtype=inp_dtype,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             q_layout=q_layout,
         )

@@ -447,17 +529,24 @@ QUANTIZE_OUTPUT_DTYPES = {
     "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
 }

-ALL_QUANTIZE_TEST_SHAPES = [
-    (32, 64),
-    (2, 64, 32),
-    (32, 256, 128),
-    (64, 32, 32, 256),
+ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
+    ((32, 64), -1),
+    ((2, 64, 32), -1),
+    ((2, 64, 32), -2),
+    ((32, 256, 128), -1),
+    ((32, 256, 128), -2),
+    ((64, 32, 32, 256), -1),
+    ((64, 32, 32, 256), -2),
+    ((64, 32, 32, 256), -3),
 ]

-QUANTIZE_TEST_SHAPES = {
+QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
     "L0": [
-        (32, 256, 128),
-        (64, 32, 32, 256),
+        ((32, 64), -1),
+        ((2, 64, 32), -1),
+        ((2, 64, 32), -2),
     ],
-    "L2": ALL_QUANTIZE_TEST_SHAPES,
+    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
 }

 QUANTIZATION_INPUT_DTYPE = {

@@ -469,9 +558,8 @@ QUANTIZATION_INPUT_DTYPE = {
 @pytest.mark.skipif(not is_fp8_supported, reason=reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
+@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
 @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
 @pytest_parametrize_wrapper(
     "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
 )

@@ -524,12 +612,11 @@ class TestFusedQuantize:
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
+    @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
     @pytest_parametrize_wrapper(
         "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    @pytest_parametrize_wrapper("flatten_axis", [-1, -2])
     def test_quantize_dbias(
         self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
     ):

@@ -538,6 +625,12 @@ class TestFusedQuantize:
         ):
             pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")

+        if (flatten_axis < 0 and flatten_axis + len(input_shape) <= 0) or flatten_axis <= 0:
+            pytest.skip(
+                f"Flatten axis {flatten_axis} is not supported for input shape {input_shape}. There"
+                " must be at least one axis on either side of the flatten_axis split."
+            )
+
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)

@@ -630,16 +723,19 @@ class TestFusedQuantize:
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
     @pytest_parametrize_wrapper("is_dbias", [True, False])
     @pytest_parametrize_wrapper(
-        "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
+        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
     )
-    def test_quantize_dact_dbias_delayed_scaling(
-        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
+    @pytest_parametrize_wrapper(
+        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
+    )
+    def test_quantize_dact_dbias_tensor_scaling(
+        self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout, scaling_mode
     ):
         self._test_quantize_dact_dbias(
             in_dtype=in_dtype,
             input_shape=input_shape,
             out_dtype=out_dtype,
-            scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+            scaling_mode=scaling_mode,
             activation_type=activation_type,
             is_dbias=is_dbias,
             q_layout=q_layout,

@@ -830,7 +926,10 @@ class TestFusedDense:
         Test layernorm_dense VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")

         # zero_centered_gamma is already tested in TestNorm

@@ -916,7 +1015,10 @@ class TestFusedDense:
         Test layernorm_mlp VJP Rule
         """
         # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        ):
             pytest.skip("E5M2 is not supported in normalization with TE Backend!")

         # zero_centered_gamma is already tested in TestNorm

@@ -1052,7 +1154,7 @@ fwd_bwd_dtypes = [
     [jnp.float8_e5m2, jnp.float8_e4m3fn],
 ]

+"""
 @pytest_parametrize_wrapper(
     "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
 )

@@ -1267,3 +1369,4 @@ class TestGroupedDense:
         assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
         assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
         assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
+"""
tests/jax/test_distributed_layernorm.py

@@ -34,6 +34,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -76,6 +77,8 @@ class TestDistributedLayernorm:
         other_bytes = 0
         if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
             other_bytes = 384  # required for small scale shapes that require padding
+        if fp8_recipe == recipe.Float8CurrentScaling():
+            allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction

         return generate_collectives_count(
             allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
         )
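The collectives-count hunk above adds one element of the activation dtype to the expected all-reduce volume when Float8CurrentScaling is active, since current scaling reduces a single amax scalar across the data-parallel group. A small sketch of that accounting (the helper is hypothetical, written only to show the arithmetic):

import jax.numpy as jnp

def expected_allreduce_bytes(base_bytes, jax_dtype, current_scaling):
    # One amax value in the activation dtype is all-reduced under current scaling.
    itemsize = jnp.zeros((), dtype=jax_dtype).dtype.itemsize
    return base_bytes + (itemsize if current_scaling else 0)

print(expected_allreduce_bytes(1024, jnp.bfloat16, True))   # 1026
print(expected_allreduce_bytes(1024, jnp.float32, False))   # 1024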
tests/jax/test_distributed_layernorm_mlp.py

@@ -41,6 +41,7 @@ is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -217,37 +218,10 @@ class TestDistributedLayernormMLP:
                     m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
                 )
         else:
-            is_gated = len(activation_type) > 1
-            rtol = None
-            atol = None
-            if is_gated:
-                if dtype == jnp.bfloat16:
-                    if i == 2:
-                        rtol = 800
-                        atol = 9e-2
-                    if i == 4:
-                        atol = 300
-                        rtol = 1e-1
-                if dtype == jnp.float16:
-                    if i == 1:  # gamma
-                        rtol = 200
-                        atol = 1e-2
-                    if i == 2:
-                        rtol = 2000
-                        atol = 7e-2
-                    if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
-                        # Accumulating dbias across a large tensor introduces a larger difference
-                        rtol = 200
-                        atol = 4e-2
-                    if i == 4 and fp8_recipe == recipe.DelayedScaling():
-                        rtol = 2200
-                        atol = 9e-2
             assert_allclose(
                 multi_grads[i],
                 single_grads[i],
                 dtype=dtype,
-                rtol=rtol,
-                atol=atol,
                 err_msg=f"multi_grads[{i}] is not close",
             )
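After this change the gated-activation comparisons rely on assert_allclose's dtype-derived defaults instead of the hand-tuned rtol/atol table. As a rough illustration only (the suite's actual assert_allclose is defined in tests/jax/utils.py and may choose different values), dtype-derived tolerances typically look like:

import numpy as np

def default_tolerances(dtype):
    # Illustrative mapping: looser bounds for lower-precision dtypes.
    eps = np.finfo(np.dtype(dtype)).eps
    return {"rtol": float(np.sqrt(eps)), "atol": float(eps)}

print(default_tolerances(np.float32))
print(default_tolerances(np.float16))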
tests/jax/test_helper.py

@@ -10,47 +10,22 @@ import jax.numpy as jnp
 import numpy as np
 from utils import assert_allclose

-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import (
+    QuantizeConfig,
+    is_fp8_available,
+    ScalingMode,
+    update_collections,
+)
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)


-class TestQuantizeConfig(unittest.TestCase):
-    @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_initialize(self):
-        margin = 5.0
-        fp8_format = FP8Format.E4M3
-        amax_history_len = 10
-
-        QuantizeConfig.initialize(
-            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
-        )
-
-        self.assertEqual(
-            QuantizeConfig.MARGIN,
-            margin,
-            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
-            f" but got {QuantizeConfig.MARGIN}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.FP8_FORMAT,
-            fp8_format,
-            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {QuantizeConfig.FP8_FORMAT}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.AMAX_HISTORY_LEN,
-            amax_history_len,
-            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
-        )
-
-        QuantizeConfig.finalize()
-
+class TestHelper(unittest.TestCase):
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):

@@ -61,19 +36,19 @@ class TestQuantizeConfig(unittest.TestCase):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)


 class TestFP8Functions(unittest.TestCase):
-    def _check_defult_state(self):
+    def _check_default_state(self):
         self.assertFalse(QuantizeConfig.is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):

@@ -82,35 +57,92 @@ class TestFP8Functions(unittest.TestCase):
         self.assertTrue(ref.amax_history_len == test.amax_history_len)
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

+    def _compare_current_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+
+    def _compare_mxfp8_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast(self):
+    def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
-            self.assertFalse(QuantizeConfig.is_fp8_enabled())
-            self._compare_delay_scaling(get_delayed_scaling(), DelayedScaling())
+            self._check_default_state()

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds):
             self.assertTrue(QuantizeConfig.is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

-        self._check_defult_state()
+        self._check_default_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_default_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+            self._check_default_state()
+
+        self._check_default_state()
+
+        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_default_state()
+
+        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_default_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_default_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+            self._check_default_state()
+
+        self._check_default_state()
+
+        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_default_state()
+
+        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_default_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
-        self._check_defult_state()
+        self._check_default_state()

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)

@@ -130,4 +162,4 @@ class TestFP8Functions(unittest.TestCase):
         self._compare_delay_scaling(get_delayed_scaling(), ds)
         self.assertEqual(sr, global_mesh_resource())

-        self._check_defult_state()
+        self._check_default_state()
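For context, the recipes exercised by these tests are passed to fp8_autocast the same way in application code. A minimal usage sketch based only on the API visible above (assumes Transformer Engine is installed and an FP8-capable GPU is present):

from transformer_engine.common.recipe import Float8CurrentScaling, MXFP8BlockScaling
from transformer_engine.common.recipe import Format as FP8Format
from transformer_engine.jax import fp8_autocast

# Per-tensor current scaling, as added to these tests.
cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=cs):
    pass  # run TE/JAX modules here; quantization follows the active recipe

# MXFP8 block scaling uses the same entry point with a different recipe object.
bs = MXFP8BlockScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=bs):
    pass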
tests/jax/utils.py

@@ -13,7 +13,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn

@@ -97,16 +96,16 @@ def combine_biases(*masks: Optional[Array]):
     return mask


-def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
+def get_parameters_for_test_level(param_dict: dict):
     """
     Takes an input dictionary of parameters keyed by test type "L0", etc.
-    Returns a list of pytest parameters to be used in a parameterized test for the current test type
+    Returns the parameters for the test level specified in the environment variable
     """
     DEFAULT_TEST_LEVEL = "L0"
     test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
     if test_level not in param_dict:
         raise ValueError("Unsupported test level")
-    return values_to_named_params(param_dict[test_level], id_prefix)
+    return param_dict[test_level]


 def value_to_test_name_str(value):

@@ -139,14 +138,18 @@ def pytest_parametrize_wrapper(param_name, param_values):
     A wrapper for pytest.mark.parametrize to allow for automatic
     naming of tests based on the parameter values.
     """
-    id_prefix = param_name
     if isinstance(param_values, dict):
-        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
-    elif "," not in param_name:
-        param_values = values_to_named_params(param_values, id_prefix=id_prefix)
-    # Currently comma separated parameters in one parametrize call aren't supported for automatic naming
-    # and will just be passed through with default pytest names
+        # If the values are split into a dictionary of test-levels, e.g. "L0", etc.,
+        # unwrap the selected level before proceeding.
+        param_values = get_parameters_for_test_level(param_values)
+
+    if "," not in param_name:
+        # Multi-parameterize annotations are not supported in this wrapper
+        # and are just a passthrough to default pytest.mark.parametrize.
+        # E.g. @pytest_parametrize_wrapper("a,b", ((a_value1, b_value1), (a_value2, b_value2)))
+        # will be passed through to pytest.mark.parametrize as-is without pytest.param ids.
+        param_values = values_to_named_params(param_values, id_prefix=param_name)

     def decorator(func):
         return pytest.mark.parametrize(param_name, param_values)(func)

@@ -312,16 +315,22 @@ class DenseGeneral(nn.Module):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_param_shape,
+            self.dtype,
         )
         kernel = jnp.asarray(kernel, input_dtype)
         kernel = jnp.reshape(kernel, kernel_shape)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                self.features,
+                self.dtype,
             )
             bias = bias.astype(input_dtype)
         else:

@@ -418,9 +427,9 @@ class MlpBlock(nn.Module):
         )  # Broadcast along length.
         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
         output = DenseGeneral(
             inputs.shape[-1],
             dtype=self.dtype,

@@ -684,21 +693,13 @@ class MultiHeadAttention(nn.Module):
         value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
         if self.transpose_batch_sequence:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("length", "batch", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("length", "batch", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
         else:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("batch", "length", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("batch", "length", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

         if decode:
             # Detect if we're initializing by absence of existing cache data.

@@ -805,9 +806,9 @@ class MultiHeadAttention(nn.Module):
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.

@@ -853,8 +854,11 @@ class LayerNorm(nn.Module):
         input_dtype = x.dtype
         features = x.shape[-1]
-        scale = nn_partitioning.param_with_axes(
-            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
+        scale = self.param(
+            "scale",
+            nn.with_logical_partitioning(self.scale_init, ("embed",)),
+            (features,),
+            self.dtype,
         )

         x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":

@@ -862,8 +866,11 @@ class LayerNorm(nn.Module):
             var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
             y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
-            bias = nn_partitioning.param_with_axes(
-                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
+            bias = self.param(
+                "ln_bias",
+                nn.with_logical_partitioning(self.bias_init, ("embed",)),
+                (features,),
+                self.dtype,
             )
             bias = jnp.asarray(bias, input_dtype)

@@ -972,12 +979,11 @@ class RelativePositionBiases(nn.Module):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
             (self.num_heads, self.num_buckets),
             jnp.float32,
-            axes=("heads", "relpos_buckets"),
         )

         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)

@@ -1555,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
     """
     src_values = {}
     for key, value in jax.tree_util.tree_leaves_with_path(src):
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         src_values[normalized_key] = value

     flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
     synced_dst_values = []
     for key, value in flatten_dst:
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         if normalized_key in transformations:
             corresponding_src_key = transformations[normalized_key]
         else:
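The utils.py hunks above replace the flax.linen.partitioning helpers (param_with_axes, with_sharding_constraint) with the nn.with_logical_partitioning / nn.with_logical_constraint idiom. A standalone toy module showing that idiom (illustrative only, not taken from the suite):

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyDense(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        # Logical axis names are attached to the parameter via the wrapped initializer.
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        y = x @ kernel
        # Logical sharding constraints are applied to activations the same way.
        return nn.with_logical_constraint(y, ("batch", "mlp"))

params = TinyDense(features=16).init(jax.random.PRNGKey(0), jnp.ones((4, 8)))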
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
View file @
f8c2af4c
...
@@ -16,6 +16,7 @@ import torch.distributed as dist
 from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
+    Float8BlockScaling,
     Format,
     Recipe,
 )
...
@@ -26,6 +27,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor

 def _get_raw_data(quantized_tensor):
...
@@ -34,6 +36,14 @@ def _get_raw_data(quantized_tensor):
         assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
         assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
         return quantized_tensor._data
+    elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
+        assert hasattr(quantized_tensor, "_rowwise_data"), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
+        assert (
+            quantized_tensor._rowwise_data.dtype == torch.uint8
+        ), "Float8BlockwiseQTensor _rowwise_data must be uint8"
+        return quantized_tensor._rowwise_data
     else:
         raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
...
@@ -435,15 +445,15 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256, **linear_kwargs),
-            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 16, **linear_kwargs),
+            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256, **linear_kwargs),
-        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 16, **linear_kwargs),
+        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )
...
@@ -539,12 +549,13 @@ def _test_zero_1(dp_group):
 def quantization_recipe(quantization) -> Recipe:
     """Quantization recipe setup"""
+    fp8_format = Format.HYBRID
     if quantization == "fp8":
-        return DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max")
+        return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
     elif quantization == "fp8_cs":
-        return Float8CurrentScaling()
+        return Float8CurrentScaling(fp8_format=fp8_format)
+    elif quantization == "fp8_block":
+        return Float8BlockScaling(fp8_format=fp8_format)
     else:
         raise ValueError(f"Unsupported quantization: {quantization}")
...
@@ -568,15 +579,15 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
         preserve_high_precision_init_val=True,
     ):
         model_fp8 = nn.Sequential(
-            te.Linear(128, 256, **linear_kwargs),
-            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(128, 256 + 16, **linear_kwargs),
+            te.Linear(256 + 16, 256 * 3, **linear_kwargs),
             te.Linear(256 * 3, 128, **linear_kwargs),
         )

     # Create model with BF16 weights
     model = nn.Sequential(
-        te.Linear(128, 256, **linear_kwargs),
-        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(128, 256 + 16, **linear_kwargs),
+        te.Linear(256 + 16, 256 * 3, **linear_kwargs),
         te.Linear(256 * 3, 128, **linear_kwargs),
     )
...
@@ -593,7 +604,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
     optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
     optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)

-    for _ in range(100):
+    for i in range(100):
         for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
             w_fp8.main_grad.zero_()
             w.main_grad.zero_()
...
@@ -654,7 +665,9 @@ def main(argv=None, namespace=None):
     dist.init_process_group(**dist_init_kwargs)

     parser = argparse.ArgumentParser()
-    parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
+    parser.add_argument(
+        "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"]
+    )
     args = parser.parse_args(argv, namespace)
     dp_group = dist.new_group(backend="nccl")
...
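Aside: the raw-data accessor the test relies on boils down to an isinstance dispatch over TE's quantized tensor classes. A condensed sketch of that pattern (attribute names are taken from the diff; they are private TE implementation details and may change):

import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor


def raw_uint8_payload(quantized_tensor):
    """Return the uint8 storage backing a quantized weight (sketch of _get_raw_data)."""
    if isinstance(quantized_tensor, Float8Tensor):
        data = quantized_tensor._data
    elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
        data = quantized_tensor._rowwise_data
    else:
        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
    assert data.dtype == torch.uint8
    return data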
tests/pytorch/distributed/run_gemm_with_overlap.py
View file @
f8c2af4c
...
@@ -21,7 +21,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
-from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+from transformer_engine.pytorch.module.base import (
+    fill_userbuffers_buffer_for_all_gather,
+    get_cublas_workspace_size_bytes,
+)

 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
...
@@ -57,7 +61,11 @@ def _parse_args(argv=None, namespace=None):
     )
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     parser.add_argument(
-        "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+        "--quantization",
+        type=str.lower,
+        default="none",
+        choices=["none", "fp8", "mxfp8"],
+        help="Quantization recipe",
     )
     parser.add_argument(
         "--fp8-output", action="store_true", default=False, help="Get FP8 output from GEMM."
...
@@ -155,9 +163,9 @@ def _parse_args(argv=None, namespace=None):
         if opts.atomic:
             warnings.warn("Atomic GEMM is not supported with bulk overlap.")
             opts.atomic = False
-        if opts.fp8:
+        if opts.quantization != "none":
             warnings.warn("Bulk overlap is supported in FP8 but only tested in BF16.")
-            opts.fp8 = False
+            opts.quantization = "none"
     elif opts.comm_type == tex.CommOverlapType.AG:
         if opts.atomic:
             setattr(opts, "atomic_rs_p2p", opts.p2p)
...
@@ -165,8 +173,11 @@ def _parse_args(argv=None, namespace=None):
     if opts.atomic:
         if not te.fp8.check_fp8_support():
-            assert not opts.fp8, "Atomic GEMM is only supported in FP8."
-        opts.fp8 = True
+            assert opts.quantization == "none", "Atomic GEMM is only supported in FP8."
+        opts.quantization = "fp8"
+
+    if opts.fp8_output:
+        assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute."

     return opts
...
@@ -303,7 +314,11 @@ def _main(opts):
     inp_shape = (opts.seq_length, opts.batch_size, hidden_size)
     outer_size = reduce(operator.mul, inp_shape[:-1], 1)
     buffer_dtype = torch.bfloat16
-    if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG:
+    if (
+        opts.quantization != "none"
+        and not opts.bulk_overlap
+        and opts.comm_type == tex.CommOverlapType.AG
+    ):
         buffer_dtype = torch.uint8
     ub_obj = (
         tex.CommOverlapP2P(
...
@@ -450,6 +465,8 @@ def _main(opts):
         inp2_g = torch.nn.functional.gelu(ref_g)  # pylint: disable=not-callable
         ref2_g = torch.matmul(inp2_g, ker2_g)

+    # Initialize quantizers
+    with_quantized_compute = opts.quantization != "none"
     inp_quantizer = None
     ker_quantizer = None
     out_quantizer = None
...
@@ -457,7 +474,7 @@ def _main(opts):
     inp2_quantizer = None
     ker2_quantizer = None
     out2_quantizer = None
-    if opts.fp8:
+    if opts.quantization == "fp8":
         # Structure to maintain amax and scale/scale_inv information for the kernel and input
         num_gemms = 6 if ub_obj2 is not None else 3
         fp8_dtype = tex.DType.kFloat8E4M3
...
@@ -502,11 +519,23 @@ def _main(opts):
             out2_quantizer = Float8Quantizer(
                 fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype
             )
+    elif opts.quantization == "mxfp8":
+        fp8_dtype = tex.DType.kFloat8E4M3
+        inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        ker_quantizer = MXFP8Quantizer(fp8_dtype)
+        if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
+            bulk_inp_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+        elif ub_obj2 is not None:
+            inp2_quantizer = MXFP8Quantizer(fp8_dtype, columnwise=False)
+            ker2_quantizer = MXFP8Quantizer(fp8_dtype)

-    if opts.fp8:
-        # Cast input to Float8Tensor
+    # Quantize tensors
+    if with_quantized_compute:
+        # Quantize input tensor
         inp_fp8 = inp_quantizer(inp)

-        # Cast kernel to Float8Tensor
+        # Quantize kernel tensor
         kernel_t_fp8 = ker_quantizer(kernel_t)

         if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS:
             bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp)
...
@@ -543,31 +572,40 @@ def _main(opts):
     )

     # Set up comm/compute buffers
+    ag_out = None
     rs_out = None
     rs_out2 = None
     if opts.comm_type == tex.CommOverlapType.AG:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True)
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                bulk_inp,
+                bulk_inp_quantizer,
+                tp_group,
+            )
             gemm_inp = inp
         else:
-            ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True)
-            gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size())
+            ag_out, _ = fill_userbuffers_buffer_for_all_gather(
+                ub_obj,
+                inp_fp8 if with_quantized_compute else inp,
+                inp_quantizer,
+                tp_group,
+            )
+            gemm_inp = ag_out
         if ub_obj2 is not None:
-            if opts.fp8 and opts.fp8_output:
-                ub_obj2.set_buffer_params(out_quantizer)
             rs_out2 = torch.empty(
                 (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
             )
     else:
         if opts.bulk_overlap:
-            ub_obj.copy_into_buffer(
-                bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False
-            )
-            if opts.fp8:
-                ub_obj.set_buffer_params(bulk_inp_quantizer)
-        elif opts.fp8 and opts.fp8_output:
-            ub_obj.set_buffer_params(out_quantizer)
-        gemm_inp = inp_fp8 if opts.fp8 else inp
+            if opts.quantization == "none":
+                ub_obj.copy_into_buffer(bulk_inp, local_chunk=False)
+            if opts.quantization == "fp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._data, local_chunk=False)
+            elif opts.quantization == "mxfp8":
+                ub_obj.copy_into_buffer(bulk_inp_fp8._rowwise_data, local_chunk=False)
+        gemm_inp = inp_fp8 if with_quantized_compute else inp
         rs_out = torch.empty(
             (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda"
         )
...
@@ -626,7 +664,7 @@ def _main(opts):
     if opts.use_cuda_graphs:
         # Trace the CUDA graph first
         g = torch.cuda.CUDAGraph()
-        if opts.fp8:
+        if with_quantized_compute:
             if ub_obj is None:
                 with torch.cuda.graph(g):
                     all_outputs = _fp8_gemm()
...
@@ -646,7 +684,7 @@ def _main(opts):
     else:
         for i in range(total_iters):
-            if opts.fp8:
+            if with_quantized_compute:
                 start_events[i].record()
                 all_outputs = _fp8_gemm()
                 end_events[i].record()
...
@@ -691,10 +729,22 @@ def _main(opts):
         output_info = ""
         if opts.comm_type == tex.CommOverlapType.AG:
             # Bulk overlap AG output is already gathered
-            test_out = ub_obj.get_buffer(bulk_inp_quantizer, False)
+            test_out = ag_out
+            if bulk_inp_quantizer is None:
+                test_out = ub_obj.get_buffer(False)
+            else:
+                test_out = Float8Tensor(
+                    shape=test_out.shape,
+                    dtype=torch.bfloat16,
+                    data=ub_obj.get_buffer(False),
+                    fp8_scale=bulk_inp_quantizer.scale,
+                    fp8_dtype=bulk_inp_quantizer.dtype,
+                    quantizer=bulk_inp_quantizer,
+                )
         else:
             # Bulk overlap RS output needs to be gathered
-            out_local = ub_obj.get_buffer(bulk_inp_quantizer, True)
+            out_local = ub_obj.get_buffer(True)
             output_info += f"rs_output: {list(out_local.shape)} | "
             test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0]
...
@@ -765,8 +815,8 @@ def _main(opts):
     m = torch.argmax(diff)
     abs_err = diff[m].item()
     rel_err = abs_err / max(abs(ref_out.flatten()[m].item()), 1e-5)
-    rtol = 0.125 if opts.fp8 else 0.02
-    atol = 0.0625 if opts.fp8 else 0.001
+    rtol = 0.02 if opts.quantization == "none" else 0.125
+    atol = 0.001 if opts.quantization == "none" else 0.0625
     if rel_err > rtol and abs_err > atol:
         numerics_failed = True
         numerics_info = (
...
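Aside: the common thread in this file is that the boolean --fp8 switch becomes a string-valued --quantization choice, and the rest of the script branches on it. A minimal, standalone sketch of that flag migration (illustrative only; the real script defines many more options):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--quantization",
    type=str.lower,
    default="none",
    choices=["none", "fp8", "mxfp8"],
    help="Quantization recipe",
)
opts = parser.parse_args(["--quantization", "MXFP8"])  # case-insensitive thanks to str.lower

with_quantized_compute = opts.quantization != "none"
print(opts.quantization, with_quantized_compute)  # mxfp8 True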
tests/pytorch/distributed/run_layer_with_overlap.py
View file @
f8c2af4c
...
@@ -17,7 +17,12 @@ import torch
 import torch.distributed as dist
 import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+    Format,
+    MXFP8BlockScaling,
+)

 warnings.filterwarnings("ignore", category=DeprecationWarning)
 warnings.filterwarnings("ignore", category=FutureWarning)
...
@@ -163,7 +168,7 @@ def _parse_args(argv=None, namespace=None):
         "--quantization",
         type=str.lower,
         default="none",
-        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"],
+        choices=["none", "fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
         help="Quantization recipe",
     )
     parser.add_argument(
...
@@ -414,6 +419,8 @@ def _train(opts):
         )
     elif opts.quantization == "fp8_current_scaling":
         fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format)
+    elif opts.quantization == "mxfp8":
+        fp8_recipe = MXFP8BlockScaling()

     # Prepare random input tensors
     test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True)
...
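Aside: a compact sketch of how the quantization string maps onto recipe objects in this script (constructor calls as they appear in the diff; surrounding defaults in the file are elided, and the helper name is invented):

from transformer_engine.common.recipe import (
    DelayedScaling,
    Float8CurrentScaling,
    Format,
    MXFP8BlockScaling,
)


def recipe_for(quantization, fp8_format=Format.HYBRID):
    """Map a --quantization value to a TE recipe; None means plain BF16/FP32 compute."""
    if quantization == "fp8_delayed_scaling":
        return DelayedScaling(fp8_format=fp8_format)
    if quantization == "fp8_current_scaling":
        return Float8CurrentScaling(fp8_format=fp8_format)
    if quantization == "mxfp8":
        return MXFP8BlockScaling()
    return None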
tests/pytorch/distributed/run_numerics.py
View file @
f8c2af4c
...
@@ -174,7 +174,7 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.3e-6, "atol": 1e-5}
+        return {"rtol": 1.3e-6, "atol": 4e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")
...
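Aside: the loosened float32 atol feeds directly into torch.testing.assert_close. A short usage sketch (the example tensors are made up; only the tolerance values come from the diff):

import torch


def _get_tolerances(dtype):
    if dtype == torch.bfloat16:
        return {"rtol": 1.6e-2, "atol": 1e-5}
    if dtype == torch.float32:
        return {"rtol": 1.3e-6, "atol": 4e-5}
    raise ValueError(f"Unsupported dtype ({dtype})")


a = torch.randn(1024, dtype=torch.float32)
b = a + 2e-5  # within the new 4e-5 absolute tolerance, outside the old 1e-5
torch.testing.assert_close(a, b, **_get_tolerances(torch.float32))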
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
View file @
f8c2af4c
...
@@ -15,6 +15,9 @@ if torch.cuda.device_count() < 2:
     pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")

 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)

 TEST_ROOT = Path(__file__).parent.resolve()
 NUM_PROCS: int = min(2, torch.cuda.device_count())
...
@@ -28,8 +31,10 @@ def _run_test(quantization):
     assert result.returncode == 0


-@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
+@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"])
 def test_cast_master_weights_to_fp8(quantization):
-    if not fp8_available:
+    if quantization in ("fp8", "fp8_cs") and not fp8_available:
         pytest.skip(reason_for_no_fp8)
+    if quantization == "fp8_block" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     _run_test(quantization)
tests/pytorch/distributed/test_comm_gemm_overlap.py
View file @
f8c2af4c
...
@@ -21,6 +21,7 @@ if torch.cuda.device_count() < 2:
     pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.")

 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()

 RNG_SEED: int = 42
 SEQ_LENGTH: int = 1024
...
@@ -56,7 +57,7 @@ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 torch._dynamo.reset()


-def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
+def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, quantization):
     test_path = TEST_ROOT / "run_gemm_with_overlap.py"
     test_cmd = LAUNCH_CMD + [
         str(test_path),
...
@@ -72,10 +73,11 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8):
     if bulk:
         test_cmd.append("--bulk-overlap")
     else:
-        if fp8:
-            if not fp8_available:
-                pytest.skip(reason_for_no_fp8)
-            test_cmd.append("--fp8")
+        if quantization == "fp8" and not fp8_available:
+            pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+        test_cmd.append(f"--quantization={quantization}")
     if p2p:
         test_cmd.append("--p2p")
     if atomic:
...
@@ -114,8 +116,10 @@ def _run_layer_with_overlap(
         test_cmd.append("--overlap-rs-dgrad")

     if fp8:
-        if not fp8_available:
+        if quantization in ("fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available:
             pytest.skip(reason_for_no_fp8)
+        if quantization == "mxfp8" and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
         test_cmd.append("--fp8")
     test_cmd.append(f"--quantization={quantization}")
...
@@ -137,51 +141,34 @@ def _run_layer_with_overlap(
         raise AssertionError(result.stderr.decode())


-@pytest.mark.parametrize(
-    "fp8",
-    (False, True),
-    ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "],
-)
-def test_split_all_gather_overlaps(fp8):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+def test_split_all_gather_overlaps(quantization):
     """
     Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("AG", False, True, False, fp8)
+    _run_gemm_with_overlap("AG", False, True, False, quantization)


-@pytest.mark.parametrize(
-    "fp8,p2p",
-    [
-        (False, False),
-        (False, True),
-        (True, False),
-        (True, True),
-    ],
-    ids=[
-        " BF16 - PIPELINE ",
-        " BF16 - RING-EXCHANGE ",
-        " FP8 - PIPELINE ",
-        " FP8 - RING-EXCHANGE ",
-    ],
-)
-def test_split_reduce_scatter_overlaps(fp8, p2p):
+@pytest.mark.parametrize("quantization", ("none", "fp8", "mxfp8"))
+@pytest.mark.parametrize("p2p", (False, True))
+def test_split_reduce_scatter_overlaps(quantization, p2p):
     """
     Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or
     te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap("RS", False, p2p, False, fp8)
+    _run_gemm_with_overlap("RS", False, p2p, False, quantization)


 @pytest.mark.parametrize(
-    "comm_type, fp8, connections",
+    "comm_type, quantization, connections",
     [
-        ("AG", False, 1),
-        ("RS", False, 1),
-        ("RS", True, 1),
-        ("AG", False, 8),
-        ("RS", False, 8),
-        ("RS", True, 8),
+        ("AG", "none", 1),
+        ("RS", "none", 1),
+        ("RS", "fp8", 1),
+        ("AG", "none", 8),
+        ("RS", "none", 8),
+        ("RS", "fp8", 8),
     ],
     ids=[
         "ALL-GATHER - BF16 - 1 connections",
...
@@ -192,7 +179,7 @@ def test_split_reduce_scatter_overlaps(fp8, p2p):
         "REDUCE-SCATTER - FP8 - 8 connections",
     ],
 )
-def test_bulk_overlaps(comm_type, fp8, connections):
+def test_bulk_overlaps(comm_type, quantization, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
...
@@ -203,10 +190,10 @@ def test_bulk_overlaps(comm_type, fp8, connections):
             " 9.0 (HOPPER ARCH)."
         )
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     else:
-        _run_gemm_with_overlap(comm_type, True, False, False, fp8)
+        _run_gemm_with_overlap(comm_type, True, False, False, quantization)


 @pytest.mark.parametrize(
...
@@ -258,15 +245,7 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[" FP8 "],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
...
@@ -286,15 +265,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         )
     ),
     ids=[
-        f" {te.Linear.__name__} - ROW-PARALLEL ",
-        f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
-        f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
-        f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
+        f"{te.Linear.__name__}-row_tensor_parallel",
+        f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
+        f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
+        f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
     ]
     + [
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]),
...
@@ -302,12 +281,15 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
     ],
 )
 def test_layers_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization
+    layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization,
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
-    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization)
...
@@ -354,22 +336,11 @@ def test_multi_layer_with_overlap_bf16(
 @pytest.mark.parametrize(
     "quantization",
-    ["fp8_delayed_scaling", "fp8_current_scaling"],
-    ids=[" DELAYED SCALING ", " CURRENT SCALING "],
-)
-@pytest.mark.parametrize(
-    "fp8",
-    (True,),
-    ids=[" FP8 "],
+    ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"],
 )
 @pytest.mark.parametrize(
     "num_layers",
     (2,),
-    ids=[" 2 layers "],
 )
 @pytest.mark.parametrize(
     "layer_type,linear_parallel_mode,overlap_rs_dgrad",
...
@@ -381,7 +352,7 @@ def test_multi_layer_with_overlap_bf16(
         )
     ),
     ids=[
-        " " + " - ".join(test_name_parts) + " "
+        "-".join(test_name_parts)
         for test_name_parts in zip(
             [te.TransformerLayer.__name__ for _ in range(2)],
             ["BULK DGRAD/WGRAD", "DGRAD+RS"],
...
@@ -389,11 +360,11 @@ def test_multi_layer_with_overlap_bf16(
     ],
 )
 def test_multi_layer_with_overlap_fp8(
-    layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+    layer_type, linear_parallel_mode, overlap_rs_dgrad, quantization, num_layers
 ):
     """
     Test Transformer Engine layers with comm+GEMM overlap.
     """
     _run_layer_with_overlap(
-        layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers
+        layer_type, linear_parallel_mode, overlap_rs_dgrad, True, quantization, num_layers
     )
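Aside: the pytest wrappers in this file only build a command line and delegate to the multi-GPU helper script; the quantization choice rides along as --quantization=<value>. A trimmed sketch of that flow, with illustrative stand-ins for the module-level objects (the launch command below is an assumption, not the test's actual LAUNCH_CMD):

import os
import subprocess
from pathlib import Path

import pytest

# Illustrative stand-ins for the module-level objects in the real test file.
TEST_ROOT = Path(__file__).parent
LAUNCH_CMD = ["torchrun", "--nproc_per_node=2"]
fp8_available, reason_for_no_fp8 = True, "FP8 not available"
mxfp8_available, reason_for_no_mxfp8 = False, "MXFP8 not available"


def _run_gemm_with_overlap(quantization):
    """Trimmed sketch: forward the quantization choice to the multi-GPU helper script."""
    test_cmd = LAUNCH_CMD + [str(TEST_ROOT / "run_gemm_with_overlap.py")]
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    test_cmd.append(f"--quantization={quantization}")
    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    assert result.returncode == 0, result.stderr.decode()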
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
View file @
f8c2af4c
...
@@ -19,7 +19,6 @@ import torch
 import transformer_engine
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
-from transformer_engine.pytorch.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
...
@@ -27,6 +26,8 @@ from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
 )
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
 from transformer_engine.pytorch.utils import is_bf16_compatible

 # Import utility functions
...
@@ -36,6 +37,13 @@ from utils import dtype_tols, str_to_dtype

 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+
+quantization_list: list[Optional[str]] = [None]
+if fp8_available:
+    quantization_list.append("fp8")
+if mxfp8_available:
+    quantization_list.append("mxfp8")

 # Check if there are multiple GPUs
 if torch.cuda.device_count() < 2:
...
@@ -51,7 +59,7 @@ class ModelConfig:
     num_heads: int
     head_dim: int
     dtype: torch.dtype
-    fp8: bool
+    quantization: Optional[str]

     @property
     def hidden_size(self):
...
@@ -129,12 +137,16 @@ def make_reference_and_test_tensors(
     ref = torch.rand(shape, dtype=ref_dtype, device=ref_device)

     # Make copy of tensor
+    test = ref.to(device=test_device, dtype=test_dtype)
     if test_is_fp8:
-        test = Float8Tensor.to_float8(ref)
-    else:
-        test = ref.to(device=test_device, dtype=test_dtype)
-        if test.data_ptr() == ref.data_ptr():
-            test = test.clone()
+        quantizer = Float8Quantizer(
+            scale=torch.ones(1, dtype=torch.float32, device=test_device),
+            amax=torch.zeros(1, dtype=torch.float32, device=test_device),
+            fp8_dtype=tex.DType.kFloat8E4M3,
+        )
+        test = quantizer(test)
+    elif test.data_ptr() == ref.data_ptr():
+        test = test.clone()

     # Make sure reference and test tensors represent exact same values
     ref.copy_(test)
...
@@ -145,6 +157,21 @@ def make_reference_and_test_tensors(
     return ref, test


+def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
+    """Make recipe for quantization scheme"""
+    if name is None:
+        return None
+    if name == "fp8":
+        return transformer_engine.common.recipe.DelayedScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    if name == "mxfp8":
+        return transformer_engine.common.recipe.MXFP8BlockScaling(
+            fp8_format=transformer_engine.common.recipe.Format.E4M3,
+        )
+    raise ValueError(f"Unsupported quantization scheme ({name})")
+
+
 def _test_linear(
     *,
     model_config: ModelConfig,
...
@@ -155,7 +182,8 @@ def _test_linear(
     weight_requires_grad: bool = True,
 ) -> None:
     dtype = model_config.dtype
-    fp8_compute = model_config.fp8
+    quantization = model_config.quantization
+    quantized_compute = quantization is not None

     # Distributed process group
     process_group = world_group()
...
@@ -175,14 +203,19 @@ def _test_linear(
         in_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(x_test, QuantizedTensor):
+        with torch.no_grad():
+            x_test = x_test.dequantize().requires_grad_()
     w_ref, w_test = make_reference_and_test_tensors(
         (out_features, in_features),
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
     )
+    if isinstance(w_test, QuantizedTensor):
+        w_test = w_test.dequantize()
     b_ref, b_test = None, None
     if bias:
         if tensor_parallel_mode == "row":
...
@@ -198,9 +231,11 @@ def _test_linear(
         out_shape,
         test_dtype=dtype,
         test_device=device,
-        test_is_fp8=fp8_compute,
+        test_is_fp8=quantized_compute,
         requires_grad=False,
     )
+    if isinstance(dy_test, QuantizedTensor):
+        dy_test = dy_test.dequantize()

     # Plain PyTorch implementation
     y_ref = torch.nn.functional.linear(x_ref, w_ref)
...
@@ -265,21 +300,15 @@ def _test_linear(
     x_test.requires_grad_()

     # Implementation with fusible operation
-    with te.fp8_model_init(enabled=fp8_compute):
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_compute, recipe=recipe):
         ops = []
         linear_op = None
         bias_op = None
         if tensor_parallel_mode == "column":
             userbuffers_options = {}
             if not weight_requires_grad:
-                if fp8_compute:
-                    userbuffers_options["comm_name"] = "fc1"
-                else:
-                    # There is a correctness bug with overlapping
-                    # dgrad reduce-scatter with dgrad GEMM. Fall back
-                    # to overlapping dgrad reduce-scatter with wgrad
-                    # GEMM, even though wgrad isn't needed.
-                    userbuffers_options["comm_name"] = "qkv"
+                userbuffers_options["comm_name"] = "fc1"
             else:
                 userbuffers_options["comm_name"] = "qkv"
             linear_op = te_ops.BasicLinear(
...
@@ -322,7 +351,7 @@ def _test_linear(
             bias_op.bias.copy_(b_test)
         del w_test
         del b_test
-    with te.fp8_autocast(enabled=fp8_compute):
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
         y_test = model(x_test)
     y_test.backward(dy_test)
...
@@ -338,7 +367,7 @@ def _test_linear(
     tols = dtype_tols(dtype)
     if dtype == torch.float32:
         tols = dtype_tols(torch.float16)  # TF32 GEMM
-    if fp8_compute:
+    if quantized_compute:
         tols = dtype_tols(
             model[0].weight._fp8_dtype
             if is_float8_tensor(model[0].weight)
...
@@ -370,7 +399,7 @@ def run_parallel_tests(model_config: ModelConfig) -> None:
     for test_config in itertools.product(
         (False, True),  # bias
         ("column", "row"),  # tensor_parallel_mode
-        (False, True),  # weight_requires_grad
+        (True, False),  # weight_requires_grad
     ):
         if rank == 0:
             print(f"Running _test_linear with {test_config=}")
...
@@ -390,19 +419,15 @@ if torch.cuda.device_count() > 1:
 @pytest.mark.parametrize("world_size", _world_sizes)
-@pytest.mark.parametrize("fp8", (False, True))
+@pytest.mark.parametrize("quantization", quantization_list)
 def test_fuser_ops_with_userbuffers(
     *,
     world_size: int,
     dtype: torch.dtype = torch.bfloat16,
-    fp8: bool,
+    quantization: Optional[str],
 ) -> None:
     """Launch parallel job and run tests"""

-    # Skip invalid configurations
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
-
     # Parallel job launcher
     command = []
     if tex.ubuf_built_with_mpi():
...
@@ -424,8 +449,8 @@ def test_fuser_ops_with_userbuffers(
             str(dtype),
         )
     )
-    if fp8:
-        command.append("--fp8")
+    if quantization is not None:
+        command.extend(("--quantization", quantization))

     # Environment
     env = dict(os.environ)
...
@@ -445,12 +470,12 @@ def main() -> None:
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
     parser.add_argument("--parallel", action="store_true", help="Run parallel tests")
-    parser.add_argument("--sequence-length", type=int, default=32)
+    parser.add_argument("--sequence-length", type=int, default=256)
     parser.add_argument("--batch-size", type=int, default=16)
     parser.add_argument("--num-heads", type=int, default=16)
-    parser.add_argument("--head-dim", type=int, default=32)
+    parser.add_argument("--head-dim", type=int, default=256)
     parser.add_argument("--dtype", type=str, default="bfloat16")
-    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument("--quantization", type=str, default=None)
     args = parser.parse_args()

     # Run parallel tests if needed
...
@@ -463,14 +488,17 @@ def main() -> None:
         num_heads=args.num_heads,
         head_dim=args.head_dim,
         dtype=str_to_dtype(args.dtype),
-        fp8=args.fp8,
+        quantization=args.quantization,
     )

-    # Initialize Userbuffers
     group = world_group()  # Initialize NCCL
     bootstrap_backend = "mpi" if launcher() == "ompi" else "nccl"
     userbuffer_configs = {
-        "fc1_dgrad": {"method": "pipeline"},  # Overlap dgrad RS with dgrad GEMM
+        "fc1_dgrad": {
+            "method": "ring_exchange",
+            "fp8_buf": False,
+        },  # Overlap dgrad RS with dgrad GEMM
     }
     te.module.base.initialize_ub(
         [
...
@@ -478,7 +506,7 @@ def main() -> None:
             model_config.num_heads * model_config.head_dim,
         ],
         torch.distributed.get_world_size(group),
-        use_fp8=model_config.fp8,
+        use_fp8=model_config.quantization is not None,
         dtype=model_config.dtype,
         bootstrap_backend=bootstrap_backend,
         ub_cfgs=userbuffer_configs,
...
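Aside: the reworked helper builds its FP8 test tensors through the quantizer API rather than Float8Tensor.to_float8. A self-contained sketch of that path (requires a CUDA device; the scale/amax/fp8_dtype arguments mirror the diff, while the tensor shapes here are arbitrary):

import torch
import transformer_engine.pytorch.cpp_extensions as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

ref = torch.rand(16, 32, dtype=torch.float64)
test = ref.to(device="cuda", dtype=torch.bfloat16)

quantizer = Float8Quantizer(
    scale=torch.ones(1, dtype=torch.float32, device="cuda"),
    amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
test = quantizer(test)            # quantized copy handed to the fusible-ops model
ref.copy_(test.dequantize())      # keep the reference numerically identical to it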
tests/pytorch/fused_attn/run_fused_attn_with_cp.py
View file @
f8c2af4c
...
@@ -2,12 +2,16 @@
 #
 # See LICENSE for license information.
-import os, sys, logging
+import os
+import sys
+import logging
 from contextlib import nullcontext
 import torch
 import torch.distributed as dist
 from transformer_engine.pytorch.attention import DotProductAttention
-from transformer_engine.pytorch.attention import get_cu_seqlens_on_cp_rank
+from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import (
+    get_cu_seqlens_on_cp_rank,
+)
 import transformer_engine_torch as tex
 from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
 from transformer_engine.pytorch.fp8 import fp8_autocast
...
tests/pytorch/fused_attn/test_fused_attn.py
View file @
f8c2af4c
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-import functools
 import logging
 import math
 import os
-from importlib.metadata import version
 from typing import Any, Dict, List, Tuple, Union, Optional
 from contextlib import contextmanager
...
@@ -16,26 +13,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
-from transformer_engine.pytorch.attention import (
+from transformer_engine.pytorch.attention.dot_product_attention import (
     DotProductAttention,
-    MultiheadAttention,
     _attention_backends,
 )
-from transformer_engine.pytorch.dot_product_attention.utils import (
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
+from transformer_engine.pytorch.attention.dot_product_attention.utils import (
     FlashAttentionUtils,
     get_attention_backend,
     check_set_window_size,
     AttentionParams,
 )
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
-from transformer_engine.pytorch.constants import TE_DType
+from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 import transformer_engine.pytorch.cpp_extensions as ext
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
-    AttnBiasType,
-    AttnMaskType,
     FusedAttnBackend,
-    QKVLayout,
     fused_attn_bwd,
     fused_attn_fwd,
 )
...
@@ -50,9 +43,7 @@ from transformer_engine.pytorch.utils import (
 )
 from transformer_engine.pytorch.utils import get_cudnn_version
 import transformer_engine_torch as tex
-from transformer_engine_torch import NVTE_Fused_Attn_Backend
 from transformer_engine.pytorch.tensor.quantized_tensor import (
-    QuantizedTensor,
     Quantizer,
     prepare_for_saving,
     restore_from_saved,
...
@@ -1659,8 +1650,8 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
     )
     if is_training:
         out.backward(out_grad)
     param_names = []
     param_names.append("hidden_states.grad")
...
@@ -1910,8 +1901,8 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training):
         checkpoint_core_attention=False,
         core_attention_bias_type=config.attn_bias_type,
     )
     if is_training:
         out.backward(out_grad)
     if is_training:
         return out, (inp[0].grad, inp[1].grad, inp[2].grad)
...
@@ -2024,7 +2015,7 @@ def _run_custom_mha_fp8(dtype, config, backend):
     mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda")
     with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
         out = mha(inp, cu_seqlens, config.max_seqlen_q)
     out.backward(out_grad)

     out = torch.load("out.pt")
     dqkv = torch.load("dqkv.pt")
...