OpenDAS / TransformerEngine

Commit 44740c6c, authored Jul 22, 2025 by yuguo

    Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825
Changes: 162
Showing 20 changed files with 1463 additions and 276 deletions:

  tests/cpp/operator/test_multi_unpadding.cu (+186 / -0)
  tests/cpp/operator/test_normalization.cu (+2 / -2)
  tests/cpp/util/test_string.cpp (+1 / -1)
  tests/jax/test_custom_call_compute.py (+118 / -91)
  tests/jax/test_distributed_layernorm.py (+0 / -2)
  tests/jax/test_distributed_layernorm_mlp.py (+228 / -105)
  tests/jax/test_helper.py (+2 / -2)
  tests/jax/utils.py (+23 / -6)
  tests/pytorch/debug/run_distributed.py (+46 / -27)
  tests/pytorch/debug/test_distributed.py (+2 / -4)
  tests/pytorch/debug/test_numerics.py (+45 / -2)
  tests/pytorch/debug/test_sanity.py (+5 / -13)
  tests/pytorch/distributed/run_numerics.py (+6 / -7)
  tests/pytorch/distributed/test_fusible_ops.py (+0 / -1)
  tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py (+2 / -2)
  tests/pytorch/fused_attn/run_fused_attn_with_cp.py (+26 / -9)
  tests/pytorch/fused_attn/test_fused_attn_with_cp.py (+2 / -0)
  tests/pytorch/test_checkpoint.py (+175 / -0)
  tests/pytorch/test_fused_router.py (+394 / -0)
  tests/pytorch/test_fusible_ops.py (+200 / -2)
tests/cpp/operator/test_multi_unpadding.cu (new file, mode 100644):

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>

#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/padding.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

template <typename InputType, typename OutputType>
void compute_unpadding_ref(const std::vector<std::vector<InputType>>& input_list,
                           std::vector<std::vector<OutputType>>& output_list,
                           const std::vector<size_t>& height_list,
                           const std::vector<size_t>& width_list,
                           const std::vector<int>& padded_height_list) {
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output = output_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];
    const size_t padded_height = padded_height_list[tensor_id];

    // Only copy the valid (unpadded) portion
    for (size_t i = 0; i < height; ++i) {
      for (size_t j = 0; j < width; ++j) {
        const compute_t x = static_cast<compute_t>(input[i * width + j]);
        const OutputType y = static_cast<OutputType>(x);
        output[i * width + j] = y;
      }
    }
  }
}

template <typename InputType, typename OutputType>
void performUnpaddingTest() {
  using namespace test;

  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {{1, 1},
                                                              {1, 768},
                                                              {768, 1},
                                                              {768, 768},
                                                              {43, 43},
                                                              {43, 256},
                                                              {256, 43},
                                                              {256, 256}};
  const size_t num_tensors = tensor_dims.size();
  constexpr int align = 16;

  // Buffers for Transformer Engine implementation
  std::vector<Tensor> padded_input_list, unpadded_output_list;

  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_padded_input_list;
  std::vector<std::vector<OutputType>> ref_unpadded_output_list;
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
  std::vector<int> ref_padded_height_list(num_tensors);

  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t original_height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    const size_t padded_height = (original_height + align - 1) / align * align;

    // Input is padded tensor (padded_height x width)
    padded_input_list.emplace_back(Tensor("padded_input_" + std::to_string(tensor_id),
                                          std::vector<size_t>{padded_height, width}, itype));
    // Output is unpadded tensor (original_height x width)
    unpadded_output_list.emplace_back(Tensor("unpadded_output_" + std::to_string(tensor_id),
                                             std::vector<size_t>{original_height, width}, otype));

    auto& padded_input = padded_input_list.back();
    auto& unpadded_output = unpadded_output_list.back();

    // Fill padded input with random data (including padding area)
    fillUniform(&padded_input);
    setRandomScale(&unpadded_output);

    // Initialize reference buffers
    ref_padded_input_list.emplace_back(padded_height * width);
    ref_unpadded_output_list.emplace_back(original_height * width);

    // Copy data to reference buffers
    std::copy(padded_input.rowwise_cpu_dptr<InputType>(),
              padded_input.rowwise_cpu_dptr<InputType>() + padded_height * width,
              ref_padded_input_list.back().begin());

    ref_height_list[tensor_id] = original_height;
    ref_width_list[tensor_id] = width;
    ref_padded_height_list[tensor_id] = padded_height;
  }

  // Transformer Engine implementation
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list) -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };

  // Convert height_list to int for the API
  std::vector<int> original_height_list_int(num_tensors);
  for (size_t i = 0; i < num_tensors; ++i) {
    original_height_list_int[i] = static_cast<int>(ref_height_list[i]);
  }

  // Call unpadding API
  nvte_multi_unpadding(num_tensors,
                       make_nvte_vector(padded_input_list).data(),
                       make_nvte_vector(unpadded_output_list).data(),
                       original_height_list_int.data(),
                       0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Reference implementation
  compute_unpadding_ref<InputType, OutputType>(ref_padded_input_list, ref_unpadded_output_list,
                                               ref_height_list, ref_width_list,
                                               ref_padded_height_list);

  // Check correctness
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol, rtol] = getTolerances(otype);
    compareResults("unpadded_output", unpadded_output_list[tensor_id],
                   ref_unpadded_output_list[tensor_id].data(), true, atol, rtol);
  }
}

}  // namespace

class MultiUnpaddingTestSuite
    : public ::testing::TestWithParam<transformer_engine::DType> {};

TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = GetParam();
  const DType output_type = input_type;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType,
          performUnpaddingTest<InputType, OutputType>();
      );
  );
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    MultiUnpaddingTestSuite,
    ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiUnpaddingTestSuite::ParamType>& info) {
      std::string name = test::typeName(info.param);
      return name;
    });
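The padded heights in the test above follow the usual align-up rule, padded_height = (height + align - 1) / align * align with align = 16. A minimal Python sketch of that rounding (the helper name align_up is ours, not part of the test):

def align_up(height: int, align: int = 16) -> int:
    # Round height up to the next multiple of `align`: 1 -> 16, 43 -> 48, 256 -> 256.
    return (height + align - 1) // align * align

assert [align_up(h) for h in (1, 43, 256, 768)] == [16, 48, 256, 768]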
tests/cpp/operator/test_normalization.cu:

@@ -34,8 +34,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }
-  if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
-    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
+  if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
+    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
   }

   using WeightType = InputType;
tests/cpp/util/test_string.cpp:

@@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) {  // to_string_like
   // Non-zero integer types
   EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
-  EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1");
+  EXPECT_EQ(to_string_like(static_cast<signed char>(-1)), "-1");
   EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
   EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
   EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
tests/jax/test_custom_call_compute.py:

@@ -13,6 +13,7 @@ import operator
 from utils import (
     assert_allclose,
     pytest_parametrize_wrapper,
+    use_jax_gemm,
 )
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp

@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
-from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
     DelayedScaleQuantizer,
     ScaledTensor,
     ScaledTensor1x,
     ScaledTensor2x,

@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
         else:
             assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
     elif isinstance(a, ScaledTensor2x):
-        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
-        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
+        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
+        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
     else:
         pytest.fail("a must be a ScaledTensor object")

@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
             dq_a_i = dq_a_i.reshape(b_i.shape)
             assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
     elif isinstance(a, ScaledTensor2x):
-        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
-        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
-        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
-        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
+        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
+        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
+        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
+        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
     else:
         pytest.fail("a must be a GroupedScaledTensor object")

@@ -851,6 +851,22 @@ class TestFusedQuantize:
     )

+valid_fp8_gemm_operand_types = [
+    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
+    (jnp.float8_e5m2, jnp.float8_e4m3fn),
+    (jnp.float8_e4m3fn, jnp.float8_e5m2),
+]
+
+
+def _use_jax_fp8_gemm(enabled=False):
+    import os
+
+    if enabled:
+        os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+    elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
+        os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+
+
 class TestDense:
     def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
         if data_layout[0] == "T":

@@ -883,27 +899,47 @@ class TestDense:
     def test_gemm_bf16(self, m, n, k, data_layout):
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

-        primitive_out = tex.gemm(x, w, contracting_dims)
+        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
-    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
-    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
+        if (
+            not with_jax_gemm
+            and scaling_mode.is_1d_block_scaling()
+            and jnp.float8_e5m2 in (x_qtype, w_qtype)
+        ):
+            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2,
+            is_2x2x=False,
         )
-        primitive_out = tex.gemm(
-            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
-        )
+        with use_jax_gemm(enabled=with_jax_gemm):
+            primitive_out = tex.gemm(
+                x,
+                w,
+                contracting_dims=contracting_dims,
+                lhs_quantizer=(
+                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
+                rhs_quantizer=(
+                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
+            )
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

-        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     def test_dense_grad_bf16(self, m, n, k):

@@ -932,9 +968,9 @@ class TestDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
-    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
         data_layout = "NN"
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

@@ -956,10 +992,14 @@ class TestDense:
         value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
+            scaling_mode=scaling_mode,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
+            is_2x2x=True,
         )

         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                     value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)

@@ -969,10 +1009,10 @@ class TestDense:
             x, w, bias, data_layout
         )

-        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
-        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
-        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
-        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
+        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)

     @pytest.fixture(name="random_inputs")

@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 class TestFusedDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
-    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
         """
         Test layernorm_dense VJP Rule
         """
-        # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
-            ScalingMode.DELAYED_TENSOR_SCALING,
-            ScalingMode.CURRENT_TENSOR_SCALING,
-        ):
-            pytest.skip("E5M2 is not supported in normalization with TE Backend!")
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6

@@ -1025,8 +1058,8 @@ class TestFusedDense:
         quantizer_set = QuantizerFactory.create_set(
             scaling_mode=scaling_mode,
-            fwd_dtype=q_dtype,
-            bwd_dtype=q_dtype,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
             is_2x2x=True,
         )

@@ -1064,6 +1097,7 @@ class TestFusedDense:
         )
         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 prim_out, (
                     prim_x_grad,

@@ -1072,33 +1106,26 @@ class TestFusedDense:
                     prim_beta_grad,
                 ) = value_n_grad_prim_func(x, w, gamma, beta)

-        assert_allclose(prim_out, ref_out, dtype=q_dtype)
-        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
-        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
-        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
+        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
+        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
         if beta is not None:
-            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
+            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
-    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
-        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
+        self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
     ):
         """
         Test layernorm_mlp VJP Rule
         """
-        # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
-            ScalingMode.DELAYED_TENSOR_SCALING,
-            ScalingMode.CURRENT_TENSOR_SCALING,
-        ):
-            pytest.skip("E5M2 is not supported in normalization with TE Backend!")
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6

@@ -1123,8 +1150,8 @@ class TestFusedDense:
         quantizer_sets = QuantizerFactory.create_set(
             n_quantizer_sets=2,
             scaling_mode=scaling_mode,
-            fwd_dtype=q_dtype,
-            bwd_dtype=q_dtype,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
             is_2x2x=True,
         )

@@ -1153,14 +1180,13 @@ class TestFusedDense:
             ln_out = _ref_jax_norm_impl(
                 x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
             )
-            # TODO: replace gemm with jnp.dot
-            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
+            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
             if use_bias:
                 bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                 linear_1_out += jnp.reshape(bias_1, bias_1_shape)

             x = _jax_act_lu(linear_1_out, activation_type)
-            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
+            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
             if use_bias:
                 bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                 linear_2_out += jnp.reshape(bias_2, bias_2_shape)

@@ -1174,6 +1200,7 @@ class TestFusedDense:
         value_n_grad_ref_func = value_and_grad(ref_func, range(6))

         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 prim_out, (
                     prim_x_grad,

@@ -1193,18 +1220,18 @@ class TestFusedDense:
             ref_bias_2_grad,
         ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

-        assert_allclose(prim_out, ref_out, dtype=q_dtype)
+        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)

-        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
+        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
         if use_bias:
-            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
+            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)

-        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
+        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
         if use_bias:
-            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
+            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)

-        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
-        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
+        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)

 # E5M2 * E5M2 is not supported

@@ -1238,7 +1265,9 @@ class TestGroupedDense:
         ref_out = []
         dim_num = (contracting_dims, ((), ()))
         for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
-            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
+            out_i = jax.lax.dot_general(
+                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
+            ) + jnp.expand_dims(bias_i, axis=0)
             ref_out.append(jnp.squeeze(out_i))
         return ref_out

@@ -1250,6 +1279,9 @@ class TestGroupedDense:
         group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
         group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
         group_sizes = jnp.diff(group_sizes)
+        # Make one empty input lhs to test empty GEMM handling
+        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
+        group_sizes = group_sizes.at[1].set(0)
         assert group_sizes.sum() == m
         # *32 to make sure that input shape works for MXFP8

@@ -1301,9 +1333,6 @@ class TestGroupedDense:
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("layout", ["NN"])
     def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
-
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
             scaling_mode=scaling_mode,

@@ -1343,9 +1372,10 @@ class TestGroupedDense:
     def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
         out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
         # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-        # and prevent them from being clamp to zero
+        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
+        # normalize the output and prevent the gradient from being too large for FP8.
         out_sum_list = [jnp.sum(out) for out in out_list]
-        return jnp.sum(jnp.asarray(out_sum_list))
+        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

     def _primitive_sum_grouped_dense(
         self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set

@@ -1353,7 +1383,7 @@ class TestGroupedDense:
         out = grouped_dense(
             x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
         )
-        return jnp.sum(jnp.asarray(out))
+        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
     def test_grouped_dense_grad_fp16(self, dtype, input_shape):

@@ -1388,9 +1418,6 @@ class TestGroupedDense:
     )
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_dense yet")
-
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         dtype = jnp.bfloat16
         x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
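The _use_jax_fp8_gemm helper above (and the use_jax_gemm context manager added in tests/jax/utils.py) both rely on setting NVTE_JAX_CUSTOM_CALLS_RE to the regex "^(?!GemmPrimitive$).+$", which keeps every Transformer Engine custom call enabled except the GEMM primitive, so GEMMs fall back to native JAX. A small self-contained check of what that pattern matches (the name "SomeOtherPrimitive" is only an illustrative stand-in, not a real TE primitive):

import re

pattern = re.compile(r"^(?!GemmPrimitive$).+$")

# Any other custom-call name still matches, so it stays routed to TE.
assert pattern.match("SomeOtherPrimitive") is not None  # illustrative name
# "GemmPrimitive" itself is excluded, so the GEMM falls back to JAX.
assert pattern.match("GemmPrimitive") is None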
tests/jax/test_distributed_layernorm.py:

@@ -75,8 +75,6 @@ class TestDistributedLayernorm:
             all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
         )
         other_bytes = 0
-        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
-            other_bytes = 384  # required for small scale shapes that require padding
         if fp8_recipe == recipe.Float8CurrentScaling():
             allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
         return generate_collectives_count(
tests/jax/test_distributed_layernorm_mlp.py:

@@ -13,6 +13,7 @@ from utils import (
     assert_tree_like_allclose,
     is_devices_enough,
     pytest_parametrize_wrapper,
+    use_jax_gemm,
 )

 from transformer_engine.common import recipe

@@ -33,6 +34,7 @@ from transformer_engine.jax.sharding import (
 )
 from transformer_engine.jax.sharding import MeshResource
 from transformer_engine.jax.quantize import QuantizerFactory
+from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability

 is_fp8_supported, reason = is_fp8_available()

@@ -146,7 +148,15 @@ class TestDistributedLayernormMLP:
     )
     def _test_layernorm_mlp_grad(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
+        self,
+        mesh_config,
+        activation_type,
+        use_bias,
+        input_shape,
+        dtype,
+        fp8_recipe,
+        use_shardy,
+        with_jax_gemm,
     ):
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config

@@ -156,6 +166,8 @@ class TestDistributedLayernormMLP:
             input_shape, activation_type, use_bias, dtype
         )
         static_inputs = [layernorm_type, activation_type]
+
+        with use_jax_gemm(enabled=with_jax_gemm):
             value_and_grad_func = jax.value_and_grad(
                 self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
             )

@@ -171,7 +183,9 @@ class TestDistributedLayernormMLP:
             # Multi GPUs
             devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
             mesh = Mesh(devices, mesh_axes)
-            with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
+            with mesh, fp8_autocast(
+                enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource
+            ):
                 k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
                 k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
                 k1_ = jax.device_put(k1, k1_sharding)

@@ -203,25 +217,32 @@ class TestDistributedLayernormMLP:
                     value_and_grad_func,
                     in_shardings=in_shardings,
                     out_shardings=out_shardings,
-                    static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
+                    static_argnums=range(
+                        len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
+                    ),
                 )  # +1 for multi_gpus

                 multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

-        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
+        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
+        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
+        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
+
         for i in range(len(inputs)):
             if multi_grads[i] is not None:
                 if isinstance(multi_grads[i], list):
                     assert isinstance(single_grads[i], list)
                     for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                         assert_allclose(
-                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
+                            m_grad,
+                            s_grad,
+                            dtype=bwd_test_type,
+                            err_msg=f"multi_grads[{i}] is not close",
                         )
                 else:
                     assert_allclose(
                         multi_grads[i],
                         single_grads[i],
-                        dtype=dtype,
+                        dtype=bwd_test_type,
                         err_msg=f"multi_grads[{i}] is not close",
                     )

@@ -232,8 +253,16 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+        self,
+        mesh_config,
+        activation_type,
+        use_bias,
+        input_shape,
+        dtype,
+        fp8_recipe,
+        with_jax_gemm,
     ):
         self._test_layernorm_mlp_grad(
             mesh_config,

@@ -243,6 +272,7 @@ class TestDistributedLayernormMLP:
             dtype,
             fp8_recipe,
             use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
         )

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)

@@ -251,19 +281,29 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad_shardy(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
+        self,
+        mesh_config,
+        activation_type,
+        use_bias,
+        input_shape,
+        dtype,
+        fp8_recipe,
+        with_jax_gemm,
     ):
+        # We don't test block scaling with Shardy because at the time of writing,
+        # it is not supported in JAX's scaled_matmul_stablehlo.
+        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
         self._test_layernorm_mlp_grad(
             mesh_config,
             activation_type,
             use_bias,
             input_shape,
             dtype,
-            fp8_recipe=recipe.DelayedScaling(),
+            fp8_recipe=fp8_recipe,
             use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
         )

     def _test_layernorm_mlp(

@@ -276,6 +316,7 @@ class TestDistributedLayernormMLP:
         use_fp8,
         fp8_recipe,
         use_shardy,
+        with_jax_gemm,
     ):
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         batch, seqlen, hidden_in = input_shape

@@ -287,6 +328,7 @@ class TestDistributedLayernormMLP:
         x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
         init_rngs = {"params": subkeys[1]}

+        with use_jax_gemm(enabled=with_jax_gemm):
             # Single GPUs
             with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                 ln_mlp_single = LayerNormMLP(

@@ -333,16 +375,48 @@ class TestDistributedLayernormMLP:
         # Make sure params values are the same
         assert_tree_like_allclose(params_sharded["params"], params_single["params"])
         assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
-        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
+
+        atol = None
+        rtol = None
+        l40_tolerance_update = (
+            get_min_device_compute_capability() == 89
+            and fp8_recipe == recipe.DelayedScaling()
+            and use_fp8
+            and dtype == jnp.float16
+            and activation_type == ("gelu",)
+        )
+        if l40_tolerance_update:
+            atol = 0.04
+            rtol = 11
+
+        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
+        # Triton backend by default. The error of
+        # the Triton FP8 gemm has been verified to be less than or equal
+        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
+        # However, Triton can auto-tune a different kernel for the single GPU
+        # and multi-GPU run in this test, meaning the diff between single GPU
+        # and multi-GPU can be larger in some cases, even though both are
+        # within tolerance to the float32 ground truth.
+        jax_triton_gemm_precision_tolerance_update = (
+            with_jax_gemm
+            and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
+            and dtype == jnp.bfloat16
+            and activation_type == ("gelu", "linear")
+        )
+        if jax_triton_gemm_precision_tolerance_update:
+            atol = 0.08
+            rtol = 15
+
+        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
-    @pytest_parametrize_wrapper("use_shardy", [False, True])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_layer(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
     ):
         self._test_layernorm_mlp(
             mesh_config,

@@ -352,7 +426,8 @@ class TestDistributedLayernormMLP:
             dtype,
             use_fp8=False,
             fp8_recipe=None,
-            use_shardy=use_shardy,
+            use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
         )

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)

@@ -362,8 +437,9 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_layer_fp8(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
     ):
         self._test_layernorm_mlp(
             mesh_config,

@@ -374,4 +450,51 @@ class TestDistributedLayernormMLP:
             use_fp8=True,
             fp8_recipe=fp8_recipe,
             use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
+        )
+
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_mlp_layer_shardy(
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
+    ):
+        self._test_layernorm_mlp(
+            mesh_config,
+            activation_type,
+            use_bias,
+            input_shape,
+            dtype,
+            use_fp8=False,
+            fp8_recipe=None,
+            use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
+        )
+
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_mlp_layer_fp8_shardy(
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
+    ):
+        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
+        self._test_layernorm_mlp(
+            mesh_config,
+            activation_type,
+            use_bias,
+            input_shape,
+            dtype,
+            use_fp8=True,
+            fp8_recipe=fp8_recipe,
+            use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
         )
tests/jax/test_helper.py:

@@ -92,7 +92,7 @@ class TestFP8Functions(unittest.TestCase):
         self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
-    def test_fp8_autocast_mxfp8_scaling(self):
+    def test_fp8_autocast_current_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

@@ -116,7 +116,7 @@ class TestFP8Functions(unittest.TestCase):
         self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
-    def test_fp8_autocast_mxfp8_scaling(self):
+    def test_fp8_autocast_mxfp8_block_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
tests/jax/utils.py:

@@ -3,11 +3,12 @@
 # See LICENSE for license information.
 """Utility for the TE layer tests"""
-import os
 import functools
 import math
 import operator
-from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
+import os
+from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
+from contextlib import contextmanager

 import jax
 import jax.numpy as jnp

@@ -20,7 +21,6 @@ from jax import random as jax_random
 import pytest

 from transformer_engine.jax.attention import (
     AttnMaskType,
-    canonicalize_attn_mask_type,
     make_swa_mask,
 )

@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
 PRNGKey = Any
 Shape = Tuple[int, ...]
-DType = jnp.dtype
-Array = Any
+DType = NewType("DType", jnp.dtype)
+Array = NewType("Array", jnp.ndarray)
 PrecisionLike = Union[
     None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
 ]

@@ -1519,7 +1519,7 @@ def dtype_tols(
             TEDType.kFloat8E5M2: jnp.float8_e5m2,
         }[dtype]
     elif isinstance(dtype, np.dtype):
-        dtype = jnp.dtype(dtype)
+        dtype = DType(dtype)

     # Expect bit-wise accuracy for integer dtypes
     if not jnp.issubdtype(dtype, jnp.floating):

@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
             fmt = fmt + "\n{}\n{}"

         jax.debug.print(fmt, *args)
+
+
+@contextmanager
+def use_jax_gemm(enabled=False):
+    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
+    try:
+        if enabled:
+            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+        yield
+    finally:
+        if enabled:
+            if orig_custom_calls_filter is None:
+                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+            else:
+                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
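The tests in this commit toggle GEMM backends by wrapping calls in this context manager. A minimal sketch of its save-and-restore behaviour, runnable against the use_jax_gemm defined above (the pre-existing filter value "some-existing-filter" is made up for the example):

import os

os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "some-existing-filter"
with use_jax_gemm(enabled=True):
    # Inside the block the GEMM custom call is filtered out, so JAX's GEMM is used.
    assert os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] == "^(?!GemmPrimitive$).+$"
# On exit the previous filter value is restored.
assert os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] == "some-existing-filter"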
tests/pytorch/debug/run_distributed.py:

@@ -16,7 +16,7 @@ import transformer_engine
 import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
-from transformer_engine.debug import set_weight_tensor_tp_group_reduce
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

 from test_numerics import (
     _emulate_linear,

@@ -45,6 +45,8 @@ FEATURE_DIRS = None
 all_boolean = [True, False]
 TEST_NR = 0

+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+

 def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
     if tp_size is None:

@@ -221,7 +223,7 @@ def run_debug_test(func):
     return wrapper

-CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
+CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
   layers:
     layer_types: [linear]
   enabled:

@@ -241,11 +243,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
       end_step: 1
 """

+CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
+  layers:
+    layer_types: [linear]
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      tensors: [activation, gradient, weight, output, wgrad, dgrad]
+      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
+      start_step : 0
+      end_step: 1
+"""
+

 def _prepare_config_test_log_distributed(config_file):
     if WORLD_RANK != 0:
         return
-    config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
+    config_file.write(
+        CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
+    )
     config_file.flush()

@@ -361,13 +379,13 @@ def test_log_expert_parallel(**kwargs):
     )

     # data parallel
     model = _init_model(weight, parallel_mode=None, name="linear1")
     model1 = _init_model(weight, parallel_mode=None, name="linear2")
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y1 = model(x)
         y2 = model1(x)
         y = y1 + y2
     y.sum().backward()
     debug_api.step()
-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y = model(x)
         if WORLD_RANK != 0:
             y = y + model1(x)

@@ -620,6 +638,7 @@ if __name__ == "__main__":
             for gather_weight in [True, False]:
                 test_log_distributed(parallel_mode, gather_weight)

+    if fp8_available:
         for parallel_mode in ["row", "column"]:
             test_disable_fp8_layer(parallel_mode)
tests/pytorch/debug/test_distributed.py:

@@ -5,7 +5,6 @@
 import os
 import subprocess
 from pathlib import Path
-
 import pytest
 import torch

@@ -21,7 +20,6 @@
     """
     if torch.cuda.device_count() < 2:
         pytest.skip("Distributed training needs at least 2 GPUs.")
-

@@ -34,6 +32,6 @@ def test_debug_distributed(feature_dirs):
     test_path = TEST_ROOT / "run_distributed.py"
     test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]

-    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
+    result = subprocess.run(test_cmd, env=os.environ, check=False, text=True)
     if result.returncode != 0:
-        raise AssertionError(result.stderr.decode())
+        raise AssertionError(f"torchrun exited with {result.returncode}")
tests/pytorch/debug/test_numerics.py:

@@ -27,6 +27,9 @@ from transformer_engine.pytorch.module.base import (
     _2X_ACC_FPROP,
     _2X_ACC_WGRAD,
 )
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

 all_boolean = [True, False]
 FP8_FORMAT = Format.HYBRID

@@ -246,8 +249,8 @@ def _init_model(weight):
     return model

-def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None):
-    with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
+    with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
         y = model(x, is_first_microbatch=is_first_microbatch)
     (y.sum() * loss_scale).backward()
     debug_api.step()

@@ -262,6 +265,18 @@ def _get_tensors():
     return x, weight

+LOGGING_CONFIG = """logging_config:
+  enabled: True
+  layers:
+    layer_types: [linear]
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      tensors: [activation, gradient, weight, output, wgrad, dgrad]
+      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
+"""
+
 DISABLE_FP8_CONFIG = Template(
     """disable_fp8_config:
   enabled: True

@@ -275,10 +290,30 @@ DISABLE_FP8_CONFIG = Template(
 )

+@create_config_file
+def run_logging_zero_numel_tensor(feature_dirs, **kwargs):
+    kwargs["config_file"].write(LOGGING_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
+    x, weight = _get_tensors()
+    x1 = x[:0, :]
+    model = _init_model(weight)
+    _ = _run_forward_backward(x1, model, fp8=False)
+    _ = _run_forward_backward(x, model, fp8=False)
+
+
+def test_logging_zero_numel_tensor(feature_dirs):
+    run_logging_zero_numel_tensor(feature_dirs)
+
+
 @pytest.mark.parametrize("fprop_fp8", all_boolean)
 @pytest.mark.parametrize("dgrad_fp8", all_boolean)
 @pytest.mark.parametrize("wgrad_fp8", all_boolean)
 def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)

@@ -318,6 +353,8 @@ def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwarg
 def test_disable_fp8_layer(feature_dirs):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_disable_fp8_layer(feature_dirs)

@@ -363,6 +400,8 @@ subset_combinations = random.sample(all_combinations, 20)
 def test_per_tensor_scaling(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
         pytest.skip("Skipping test because all parameters are False")
     run_per_tensor_scaling(

@@ -535,6 +574,8 @@ def run_per_tensor_scaling(
 def test_microbatching_per_tensor_scaling(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
         pytest.skip("Skipping test because all parameters are False")

@@ -624,6 +665,8 @@ subset_combinations = random.sample(all_combinations, 10)
 def test_fake_quant_fp8(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_fake_quant_fp8(
         feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
     )
tests/pytorch/debug/test_sanity.py:

@@ -2,27 +2,17 @@
 #
 # See LICENSE for license information.
-import functools
-import itertools
-import os
-import random
-import tempfile
-from string import Template
 import pytest
 import torch

-import nvdlfw_inspect.api as debug_api
 import transformer_engine.debug
 import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
-from transformer_engine.common.recipe import DelayedScaling, Format
-from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import _default_sf_compute
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

 from test_numerics import create_config_file

+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
 B, S, H, D = 64, 64, 64, 64

 model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]

@@ -104,4 +94,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
 @pytest.mark.parametrize("fp8", [False, True])
 @pytest.mark.parametrize("config_key", configs.keys())
 def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     _run_test(model_key, fp8, configs[config_key], feature_dirs)
tests/pytorch/distributed/run_numerics.py:

@@ -48,11 +48,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
     )

-# Disable TF32
-torch.backends.cuda.matmul.allow_tf32 = False
-torch.backends.cudnn.allow_tf32 = False
-

 # Quantization recipe setup
 def quantization_recipe() -> Recipe:
     if QUANTIZATION == "fp8":

@@ -167,7 +162,7 @@ def _gather(tensor, dim=0):
 def _constant(tensor):
-    return nn.init.constant_(tensor, 0.5)
+    return nn.init.constant_(tensor, 0.05)

 def dist_print(msg, src=None, end="\n", error=False):

@@ -190,7 +185,8 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.2e-4, "atol": 1e-4}
+        # TF32 has same mantissa bits as FP16
+        return {"rtol": 1e-3, "atol": 1e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")

@@ -521,8 +517,11 @@ def test_linear():
         {"return_bias": True},
         {"params_dtype": torch.float16},
         {"delay_wgrad_compute": True},
+        {"save_original_input": True},
     ]
     for kwargs in kwargs_list:
+        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
+            continue
         for parallel_mode in ["column", "row"]:
             for sequence_parallel in [False, True]:
                 _test_linear(parallel_mode, sequence_parallel, **kwargs)
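The looser float32 tolerance above follows the comment in the diff: with TF32 matmuls left enabled, float32 GEMMs keep only 10 explicit mantissa bits, the same as FP16, so a relative tolerance near 2^-10 (about 1e-3) is the natural choice. A one-line Python check of that arithmetic, purely illustrative and not part of the change:

# 2**-10 = 0.0009765625, i.e. roughly the 1e-3 rtol used for float32 with TF32 enabled.
assert abs(2 ** -10 - 1e-3) < 5e-5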
tests/pytorch/distributed/test_fusible_ops.py:

@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py:

@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,

@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible

 # Import utility functions

@@ -370,7 +370,7 @@ def _test_linear(
     if quantized_compute:
         tols = dtype_tols(
             model[0].weight._fp8_dtype
-            if is_float8_tensor(model[0].weight)
+            if isinstance(model[0].weight, Float8Tensor)
             else tex.DType.kFloat8E4M3
         )
tests/pytorch/fused_attn/run_fused_attn_with_cp.py:

@@ -89,7 +89,7 @@ def run_dpa_with_cp(
     # instantiate core attn module
     core_attn = DotProductAttention(
         config.num_heads,
-        config.head_dim_qk,
+        (config.head_dim_qk, config.head_dim_v),
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,

@@ -106,16 +106,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size,
             config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None

@@ -128,16 +134,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.max_seqlen_kv,
             config.batch_size,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
        )
         cu_seqlens_q = None
         cu_seqlens_kv = None

@@ -149,14 +161,19 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size * config.max_seqlen_q,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size * config.max_seqlen_q,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)

@@ -177,8 +194,8 @@ def run_dpa_with_cp(
         assert False, f"{qkv_format} is an unsupported qkv_format!"

     q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
+    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
+    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
         fp8_dtype=tex.DType.kFloat8E5M2,
tests/pytorch/fused_attn/test_fused_attn_with_cp.py:

@@ -174,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("Only fp8 works with fp8_mha=True!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently does not support FP8 attention!")

     subprocess.run(
         get_bash_arguments(
tests/pytorch/test_checkpoint.py  0 → 100644

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from __future__ import annotations
import argparse
import functools
import os
import pathlib
from typing import Optional

import pytest
import torch

import transformer_engine.pytorch as te
from utils import make_recipe

# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()

# Test cases for loading checkpoint files
_TestLoadCheckpoint_name_list: tuple[str, ...] = (
    "linear",
    "layernorm_linear",
    "layernorm_mlp",
    "layernorm",
    "rmsnorm",
    "transformer_layer",
    "ops_linear",
    "linear.fp8",
    "ops_linear.fp8",
    "linear.mxfp8",
    "ops_linear.mxfp8",
)


class TestLoadCheckpoint:
    """Tests for loading checkpoint files

    Tests assume that checkpoint files have already been created. In
    order to regenerate checkpoint files, e.g. after a breaking change
    in the checkpoint format, run this file directly as a Python
    script: `python3 test_checkpoint.py --save-checkpoint all`.
    """

    @staticmethod
    def _make_module(name: str) -> torch.nn.Module:
        """Construct a module"""
        if name == "linear":
            return te.Linear(1, 1)
        if name == "layernorm_linear":
            return te.LayerNormLinear(1, 1)
        if name == "layernorm_mlp":
            return te.LayerNormMLP(1, 1)
        if name == "layernorm":
            return te.LayerNorm(1)
        if name == "rmsnorm":
            return te.RMSNorm(1)
        if name == "transformer_layer":
            return te.TransformerLayer(1, 1, 1)
        if name == "ops_linear":
            return te.ops.Linear(1, 1)
        if name == "linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.Linear(16, 16)
        if name == "ops_linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.ops.Linear(16, 16)
        if name == "linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.Linear(32, 32)
        if name == "ops_linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.ops.Linear(32, 32)
        raise ValueError(f"Unrecognized module name ({name})")

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _checkpoint_dir() -> pathlib.Path:
        """Path to directory with checkpoint files"""

        # Check environment variable
        path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
        if path:
            return pathlib.Path(path).resolve()

        # Fallback to path in root dir
        root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
        return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint"

    @staticmethod
    def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None:
        """Save a module's checkpoint file"""

        # Path to save checkpoint
        if checkpoint_dir is None:
            checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
        checkpoint_dir.mkdir(exist_ok=True)
        checkpoint_file = checkpoint_dir / f"{name}.pt"

        # Create module and save checkpoint
        module = TestLoadCheckpoint._make_module(name)
        torch.save(module.state_dict(), checkpoint_file)
        print(f"Saved checkpoint for {name} at {checkpoint_file}")

    @pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list)
    def test_module(self, name: str) -> None:
        """Test for loading a module's checkpoint file"""

        # Skip if quantization is not supported
        quantization = None
        if "." in name:
            quantization = name.split(".")[1]
        if quantization == "fp8" and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if quantization == "mxfp8" and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

        # Construct module
        module = self._make_module(name)

        # Load checkpoint from file
        checkpoint_file = self._checkpoint_dir() / f"{name}.pt"
        if not checkpoint_file.is_file():
            raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
        state_dict = torch.load(checkpoint_file, weights_only=False)

        # Update module from checkpoint
        module.load_state_dict(state_dict, strict=True)


def main() -> None:
    """Main function

    Typically used to generate checkpoint files.
    """

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-checkpoint",
        type=str,
        default=None,
        help="Save checkpoint file for a module",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default=None,
        help="Directory to save checkpoint file in",
    )
    args = parser.parse_args()

    # Save checkpoint files if needed
    if args.save_checkpoint is not None:
        checkpoint_dir = args.checkpoint_dir
        if checkpoint_dir is not None:
            checkpoint_dir = pathlib.Path(checkpoint_dir).resolve()
        if args.save_checkpoint == "all":
            for name in _TestLoadCheckpoint_name_list:
                TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
        else:
            TestLoadCheckpoint._save_checkpoint(
                args.save_checkpoint,
                checkpoint_dir=checkpoint_dir,
            )


if __name__ == "__main__":
    main()
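Besides the `--save-checkpoint all` command mentioned in the class docstring, the helpers above can also be driven programmatically. A minimal sketch, not part of the committed file; the output directory is an arbitrary example, and the import assumes the snippet is run from `tests/pytorch/` so that `test_checkpoint.py` is importable:

# Regenerate one checkpoint artifact into a custom directory (illustrative only).
import pathlib
from test_checkpoint import TestLoadCheckpoint

out_dir = pathlib.Path("/tmp/te_checkpoints")  # example path, not a project default
TestLoadCheckpoint._save_checkpoint("linear", checkpoint_dir=out_dir)
# The tests then pick up the same directory via the env var read by _checkpoint_dir():
#   NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=/tmp/te_checkpoints pytest test_checkpoint.py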
tests/pytorch/test_fused_router.py  0 → 100644

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import math
from typing import Optional, Dict
from transformer_engine.pytorch.router import (
    fused_topk_with_score_function,
    fused_compute_score_for_moe_aux_loss,
    fused_moe_aux_loss,
)
import pytest
from copy import deepcopy

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)


# Pytorch-based group topk
def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
):
    group_scores = (
        scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    # Mask the experts based on selection groups
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, -1)
    )

    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)

    return probs, top_indices
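The reference routine above implements group-limited routing: each expert group is scored by the sum of its best `topk // group_topk` member scores, only the `group_topk` highest-scoring groups survive, and the final top-k is taken over experts in the surviving groups. A minimal sketch with arbitrary toy shapes, not taken from the test parametrization:

# 2 tokens, 8 experts in 4 groups of 2; keep the 2 best groups, then pick the top-4 experts.
toy_scores = torch.rand(2, 8, device="cuda")
toy_probs, toy_indices = group_limited_topk(
    scores=toy_scores,
    topk=4,
    num_tokens=2,
    num_experts=8,
    num_groups=4,
    group_topk=2,
)
# toy_probs:   (2, 4) scores of the selected experts
# toy_indices: (2, 4) expert ids, all drawn from the 2 winning groups of each token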
# Pytorch-based topk softmax/sigmoid
def topk_softmax_sigmoid_pytorch(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return topk_masked_gates, topk_map
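For orientation, the reference above returns a dense gate tensor (zeros everywhere except the selected experts) plus a boolean routing map, which is the interface `fused_topk_with_score_function` is compared against in `run_comparison` below. A quick sketch with arbitrary toy inputs:

toy_logits = torch.randn(3, 8, device="cuda")  # 3 tokens, 8 experts (toy values)
gates, routing_map = topk_softmax_sigmoid_pytorch(toy_logits, topk=2, score_function="softmax")
assert gates.shape == (3, 8) and routing_map.dtype == torch.bool
assert (routing_map.sum(dim=-1) == 2).all()  # exactly topk experts chosen per token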
# Pytorch-based compute routing scores for aux loss
def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> torch.Tensor:
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    _, top_indices = torch.topk(scores, k=topk, dim=1)
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
    return routing_map, scores


# Pytorch-based aux loss
def aux_loss_pytorch(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    topk: int,
    num_experts: int,
    moe_aux_loss_coeff: float,
):
    aggregated_probs_per_expert = probs.sum(dim=0)
    aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
        num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
    )
    return aux_loss
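The auxiliary-loss reference above computes `sum_i(P_i * c_i) * num_experts * coeff / (topk * N^2)`, where `P_i` is the routing probability mass accumulated on expert `i` over all `N` tokens and `c_i` is the token count assigned to expert `i`. A small worked call with arbitrary toy numbers, not part of the committed file:

# 4 tokens, 2 experts, top-1 routing, coeff 0.01 (toy values).
toy_probs = torch.tensor([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3], [0.6, 0.4]], device="cuda")
toy_counts = torch.tensor([4, 0], device="cuda")  # all 4 tokens routed to expert 0
loss = aux_loss_pytorch(
    probs=toy_probs,
    tokens_per_expert=toy_counts,
    total_num_tokens=4,
    topk=1,
    num_experts=2,
    moe_aux_loss_coeff=0.01,
)
# sum_i(P_i * c_i) = 3.0 * 4 + 1.0 * 0 = 12.0
# loss = 12.0 * 2 * 0.01 / (1 * 4 * 4) = 0.015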
def run_comparison(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    num_groups,
    group_topk,
    scaling_factor,
    score_function,
    enable_bias,
):
    # Set some parameters
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    if enable_bias and score_function == "sigmoid":
        expert_bias = torch.arange(num_experts, device="cuda") * 0.1
        expert_bias = torch.flip(expert_bias, dims=[0])
        expert_bias.requires_grad = True
    else:
        expert_bias = None

    # Clone the input tensor
    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True
    if expert_bias is not None:
        expert_bias_clone = deepcopy(expert_bias)
        expert_bias_clone.requires_grad = True
    else:
        expert_bias_clone = None

    # Run the original implementation
    # We do not support the capacity factor case
    probs, routing_map = topk_softmax_sigmoid_pytorch(
        logits=logits,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias,
    )

    # Run the fused implementation
    probs_fused, routing_map_fused = fused_topk_with_score_function(
        logits=logits_clone,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias_clone,
    )

    torch.testing.assert_close(probs, probs_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    # Fake the loss
    loss = torch.sum(probs)
    loss_fused = torch.sum(probs_fused)

    # Backward the loss
    loss.backward()
    loss_fused.backward()

    # Check the gradient
    torch.testing.assert_close(logits.grad, logits_clone.grad)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sigmoid(
    dtype,
    num_tokens,
    num_experts,
    topk,
    group_topk,
    scaling_factor,
    enable_bias,
):
    num_groups = 8 if group_topk else None
    run_comparison(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=False,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="sigmoid",
        enable_bias=enable_bias,
    )


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
def test_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    group_topk,
    scaling_factor,
):
    num_groups = 8 if group_topk else None
    run_comparison(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="softmax",
        enable_bias=False,
    )


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True

    routing_map, scores = compute_scores_for_aux_loss_pytorch(
        logits=logits,
        topk=topk,
        score_function=score_function,
    )

    routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss(
        logits=logits_clone,
        topk=topk,
        score_function=score_function,
    )

    torch.testing.assert_close(scores, scores_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    loss = torch.sum(scores)
    loss.backward()
    loss_fused = torch.sum(scores_fused)
    loss_fused.backward()
    torch.testing.assert_close(logits.grad, logits_clone.grad)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
    # Construct the special probs to avoid inf in the sigmoid function
    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    probs = probs.view(num_tokens, num_experts)
    probs.requires_grad = True

    tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
    coeff = 0.01

    probs_clone = deepcopy(probs)
    probs_clone.requires_grad = True

    aux_loss = aux_loss_pytorch(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        moe_aux_loss_coeff=coeff,
    )

    aux_loss_fused = fused_moe_aux_loss(
        probs=probs_clone,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=coeff,
    )

    torch.testing.assert_close(aux_loss, aux_loss_fused)

    # Backward
    aux_loss.backward()
    aux_loss_fused.backward()
    torch.testing.assert_close(probs.grad, probs_clone.grad)


def profile_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    enable_bias,
    use_pre_softmax,
):
    group_topk = 4
    scaling_factor = 1.2
    test_topk_sigmoid(
        torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
    )
    test_topk_softmax(
        torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
    )


if __name__ == "__main__":
    test_fused_scores_for_aux_loss(
        dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
    )
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
tests/pytorch/test_fusible_ops.py
...
@@ -20,8 +20,8 @@ import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.ops.fused import (
+    BackwardBiasActivation,
     BackwardLinearAdd,
     ForwardLinearBiasActivation,
     ForwardLinearBiasAdd,
...
@@ -162,7 +162,7 @@ def make_reference_and_test_tensors(
     return ref, test


-class TestSequential:
+class TestSequentialContainer:
     """Tests for sequential container"""

     def test_modules(self) -> None:
...
@@ -1878,6 +1878,98 @@ class TestFusedOps:
         db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(db_test, b_ref.grad, **tols)

+    @pytest.mark.parametrize("activation", ("relu", "gelu"))
+    @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_backward_bias_activation(
+        self,
+        *,
+        activation: str,
+        out_shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+    ) -> None:
+        """Backward dbias + dact + quantize"""
+
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        hidden_size = in_shape[-1]
+
+        # Skip invalid configurations
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, device=device)
+        if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
+            pytest.skip("Unsupported tensor size for MXFP8")
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b_ref, b_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
+        if activation == "gelu":
+            y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+        elif activation == "relu":
+            y_ref = torch.nn.functional.relu(y_ref)
+        else:
+            raise ValueError(f"Unexpected activation function ({activation})")
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
+        model = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=True),
+            te_ops.Bias(hidden_size, device=device, dtype=dtype),
+            act_type(),
+        )
+        with torch.no_grad():
+            model[1].bias.copy_(b_test)
+            del b_test
+        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+            y_test = model(x_test)
+        y_test.backward(dy_test)
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
+            assert len(backward_ops) == 2
+            assert isinstance(backward_ops[0][0], BackwardBiasActivation)
+            assert isinstance(backward_ops[1][0], te_ops.Quantize)
+        else:
+            assert len(backward_ops) == 3
+            assert isinstance(backward_ops[0][0], act_type)
+            assert isinstance(backward_ops[1][0], te_ops.Bias)
+            assert isinstance(backward_ops[2][0], te_ops.Quantize)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if with_quantization:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(db_test, b_ref.grad, **tols)
+
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
     def test_backward_linear_add(
...
@@ -2093,3 +2185,109 @@ class TestCheckpointing:
             torch.testing.assert_close(y_load, y_save, **tols)
             for x_load, x_save in zip(xs_load, xs_save):
                 torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
+
+
+class TestSequentialModules:
+    """Test for larger Sequentials with modules commonly used together"""
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("bias", (False, True))
+    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
+    @pytest.mark.parametrize("quantized_compute", (False, True))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_layernorm_mlp(
+        self,
+        *,
+        bias: bool,
+        normalization: str,
+        quantized_compute: bool,
+        quantized_weight: bool,
+        dtype: torch.dtype,
+        quantization: Optional[str],
+        device: torch.device = "cuda",
+        hidden_size: int = 32,
+        sequence_length: int = 512,
+        batch_size: int = 4,
+        ffn_hidden_size: int = 64,
+        layernorm_epsilon: float = 1e-5,
+    ) -> None:
+        """LayerNorm/RMSNorm + Linear + GELU + Linear
+
+        Note that this test only checks that the module runs, since it
+        is hard to validate numerical accuracy when chaining multiple
+        modules.
+        """
+
+        # Make input shape
+        in_shape = (sequence_length, batch_size, hidden_size)
+        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
+        quantization_needed = quantized_compute or quantized_weight
+        if quantization is None and quantization_needed:
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not quantization_needed:
+            pytest.skip("Quantization scheme is not used")
+
+        # Random data
+        _, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        _, dy_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            if normalization == "LayerNorm":
+                norm = te_ops.LayerNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            else:
+                norm = te_ops.RMSNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            ffn1 = te_ops.Linear(
+                hidden_size,
+                ffn_hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+            act = te_ops.GELU()
+            ffn2 = te_ops.Linear(
+                ffn_hidden_size,
+                hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)