OpenDAS / TransformerEngine / Commits / 44740c6c

Commit 44740c6c, authored Jul 22, 2025 by yuguo

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine

Parents: 8113d9e0, 7a9a0825
Changes: 162
Showing 20 changed files with 1463 additions and 276 deletions (+1463 -276):

tests/cpp/operator/test_multi_unpadding.cu                         +186    -0
tests/cpp/operator/test_normalization.cu                             +2    -2
tests/cpp/util/test_string.cpp                                       +1    -1
tests/jax/test_custom_call_compute.py                              +118   -91
tests/jax/test_distributed_layernorm.py                              +0    -2
tests/jax/test_distributed_layernorm_mlp.py                        +228  -105
tests/jax/test_helper.py                                             +2    -2
tests/jax/utils.py                                                  +23    -6
tests/pytorch/debug/run_distributed.py                              +46   -27
tests/pytorch/debug/test_distributed.py                              +2    -4
tests/pytorch/debug/test_numerics.py                                +45    -2
tests/pytorch/debug/test_sanity.py                                   +5   -13
tests/pytorch/distributed/run_numerics.py                            +6    -7
tests/pytorch/distributed/test_fusible_ops.py                        +0    -1
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py       +2    -2
tests/pytorch/fused_attn/run_fused_attn_with_cp.py                  +26    -9
tests/pytorch/fused_attn/test_fused_attn_with_cp.py                  +2    -0
tests/pytorch/test_checkpoint.py                                   +175    -0
tests/pytorch/test_fused_router.py                                 +394    -0
tests/pytorch/test_fusible_ops.py                                  +200    -2
tests/cpp/operator/test_multi_unpadding.cu (new file, mode 100644, +186 -0)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <cstdio>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/padding.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

template <typename InputType, typename OutputType>
void compute_unpadding_ref(const std::vector<std::vector<InputType>>& input_list,
                           std::vector<std::vector<OutputType>>& output_list,
                           const std::vector<size_t>& height_list,
                           const std::vector<size_t>& width_list,
                           const std::vector<int>& padded_height_list) {
  using compute_t = float;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto& input = input_list[tensor_id];
    auto& output = output_list[tensor_id];
    const size_t height = height_list[tensor_id];
    const size_t width = width_list[tensor_id];
    const size_t padded_height = padded_height_list[tensor_id];

    // Only copy the valid (unpadded) portion
    for (size_t i = 0; i < height; ++i) {
      for (size_t j = 0; j < width; ++j) {
        const compute_t x = static_cast<compute_t>(input[i * width + j]);
        const OutputType y = static_cast<OutputType>(x);
        output[i * width + j] = y;
      }
    }
  }
}

template <typename InputType, typename OutputType>
void performUnpaddingTest() {
  using namespace test;

  const DType itype = TypeInfo<InputType>::dtype;
  const DType otype = TypeInfo<OutputType>::dtype;
  const std::vector<std::pair<size_t, size_t>> tensor_dims = {
      {1, 1}, {1, 768}, {768, 1}, {768, 768}, {43, 43}, {43, 256}, {256, 43}, {256, 256}};
  const size_t num_tensors = tensor_dims.size();
  constexpr int align = 16;

  // Buffers for Transformer Engine implementation
  std::vector<Tensor> padded_input_list, unpadded_output_list;

  // Buffers for reference implementation
  std::vector<std::vector<InputType>> ref_padded_input_list;
  std::vector<std::vector<OutputType>> ref_unpadded_output_list;
  std::vector<size_t> ref_height_list(num_tensors), ref_width_list(num_tensors);
  std::vector<int> ref_padded_height_list(num_tensors);

  // Initialize buffers
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    const size_t original_height = tensor_dims[tensor_id].first;
    const size_t width = tensor_dims[tensor_id].second;
    const size_t padded_height = (original_height + align - 1) / align * align;

    // Input is padded tensor (padded_height x width)
    padded_input_list.emplace_back(Tensor("padded_input_" + std::to_string(tensor_id),
                                          std::vector<size_t>{padded_height, width}, itype));
    // Output is unpadded tensor (original_height x width)
    unpadded_output_list.emplace_back(Tensor("unpadded_output_" + std::to_string(tensor_id),
                                             std::vector<size_t>{original_height, width}, otype));

    auto& padded_input = padded_input_list.back();
    auto& unpadded_output = unpadded_output_list.back();

    // Fill padded input with random data (including padding area)
    fillUniform(&padded_input);
    setRandomScale(&unpadded_output);

    // Initialize reference buffers
    ref_padded_input_list.emplace_back(padded_height * width);
    ref_unpadded_output_list.emplace_back(original_height * width);

    // Copy data to reference buffers
    std::copy(padded_input.rowwise_cpu_dptr<InputType>(),
              padded_input.rowwise_cpu_dptr<InputType>() + padded_height * width,
              ref_padded_input_list.back().begin());

    ref_height_list[tensor_id] = original_height;
    ref_width_list[tensor_id] = width;
    ref_padded_height_list[tensor_id] = padded_height;
  }

  // Transformer Engine implementation
  auto make_nvte_vector = [](std::vector<Tensor>& tensor_list) -> std::vector<NVTETensor> {
    std::vector<NVTETensor> nvte_tensor_list;
    for (auto& tensor : tensor_list) {
      nvte_tensor_list.emplace_back(tensor.data());
    }
    return nvte_tensor_list;
  };

  // Convert height_list to int for the API
  std::vector<int> original_height_list_int(num_tensors);
  for (size_t i = 0; i < num_tensors; ++i) {
    original_height_list_int[i] = static_cast<int>(ref_height_list[i]);
  }

  // Call unpadding API
  nvte_multi_unpadding(num_tensors, make_nvte_vector(padded_input_list).data(),
                       make_nvte_vector(unpadded_output_list).data(),
                       original_height_list_int.data(), 0);
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // Reference implementation
  compute_unpadding_ref<InputType, OutputType>(ref_padded_input_list, ref_unpadded_output_list,
                                               ref_height_list, ref_width_list,
                                               ref_padded_height_list);

  // Check correctness
  for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
    auto [atol, rtol] = getTolerances(otype);
    compareResults("unpadded_output", unpadded_output_list[tensor_id],
                   ref_unpadded_output_list[tensor_id].data(), true, atol, rtol);
  }
}

}  // namespace

class MultiUnpaddingTestSuite
    : public ::testing::TestWithParam<transformer_engine::DType> {};

TEST_P(MultiUnpaddingTestSuite, TestMultiUnpadding) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = GetParam();
  const DType output_type = input_type;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType, performUnpaddingTest<InputType, OutputType>(););itch);
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest, MultiUnpaddingTestSuite, ::testing::ValuesIn(test::all_fp_types),
    [](const testing::TestParamInfo<MultiUnpaddingTestSuite::ParamType>& info) {
      std::string name = test::typeName(info.param);
      return name;
    });
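For readers skimming the new test: the padded height is the original height rounded up to the next multiple of 16, and unpadding keeps only the original rows. Below is a minimal NumPy sketch of that reference behaviour; the helper names are hypothetical and not part of the test suite.

import numpy as np

ALIGN = 16  # same alignment constant the C++ test uses

def padded_height(height: int, align: int = ALIGN) -> int:
    # Round up to the next multiple of `align`, i.e. (h + align - 1) / align * align.
    return (height + align - 1) // align * align

def unpad_ref(padded: np.ndarray, height: int) -> np.ndarray:
    # Keep only the valid (unpadded) rows, mirroring compute_unpadding_ref.
    return padded[:height, :].copy()

h, w = 43, 256
padded = np.random.rand(padded_height(h), w).astype(np.float32)  # shape (48, 256)
assert unpad_ref(padded, h).shape == (h, w)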
tests/cpp/operator/test_normalization.cu (+2 -2)

@@ -34,8 +34,8 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
     return;
   }
-  if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
-    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
+  if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
+    GTEST_SKIP() << "cuDNN normalizations not supported on pre-Hopper GPUs yet!";
   }

   using WeightType = InputType;
tests/cpp/util/test_string.cpp (+1 -1)

@@ -38,7 +38,7 @@ TEST(UtilTest, ToStringLike) {  // to_string_like
   // Non-zero integer types
   EXPECT_EQ(to_string_like(static_cast<char>(1)), "1");
-  EXPECT_EQ(to_string_like(static_cast<char>(-1)), "-1");
+  EXPECT_EQ(to_string_like(static_cast<signed char>(-1)), "-1");
   EXPECT_EQ(to_string_like(static_cast<unsigned char>(2)), "2");
   EXPECT_EQ(to_string_like(static_cast<short>(3)), "3");
   EXPECT_EQ(to_string_like(static_cast<short>(-5)), "-5");
tests/jax/test_custom_call_compute.py (+118 -91)

@@ -13,6 +13,7 @@ import operator
 from utils import (
     assert_allclose,
     pytest_parametrize_wrapper,
+    use_jax_gemm,
 )
 from transformer_engine.jax.layernorm import layernorm
 from transformer_engine.jax.layernorm_mlp import layernorm_mlp

@@ -30,7 +31,6 @@ from transformer_engine.jax.cpp_extensions.quantization import (
 from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
 from transformer_engine.jax import cpp_extensions as tex
 from transformer_engine.jax.quantize import (
-    DelayedScaleQuantizer,
     ScaledTensor,
     ScaledTensor1x,
     ScaledTensor2x,

@@ -109,8 +109,8 @@ def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
         else:
             assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
     elif isinstance(a, ScaledTensor2x):
-        assert_dequantized_scaled_tensor(a.get_rowwise_tensor(), b)
-        assert_dequantized_scaled_tensor(a.get_colwise_tensor(), b)
+        assert_dequantized_scaled_tensor(a.rowwise_tensor, b)
+        assert_dequantized_scaled_tensor(a.colwise_tensor, b)
     else:
         pytest.fail("a must be a ScaledTensor object")

@@ -139,10 +139,10 @@ def assert_dequantized_grouped_scaled_tensor(
             dq_a_i = dq_a_i.reshape(b_i.shape)
             assert_allclose(dq_a_i, b_i, dtype=a.data.dtype)
     elif isinstance(a, ScaledTensor2x):
-        assert isinstance(a.get_rowwise_tensor(), GroupedScaledTensor1x)
-        assert isinstance(a.get_colwise_tensor(), GroupedScaledTensor1x)
-        assert_dequantized_grouped_scaled_tensor(a.get_rowwise_tensor(), b)
-        assert_dequantized_grouped_scaled_tensor(a.get_colwise_tensor(), b)
+        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
+        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
+        assert_dequantized_grouped_scaled_tensor(a.rowwise_tensor, b)
+        assert_dequantized_grouped_scaled_tensor(a.colwise_tensor, b)
     else:
         pytest.fail("a must be a GroupedScaledTensor object")

@@ -851,6 +851,22 @@ class TestFusedQuantize:
 )

+valid_fp8_gemm_operand_types = [
+    (jnp.float8_e4m3fn, jnp.float8_e4m3fn),
+    (jnp.float8_e5m2, jnp.float8_e4m3fn),
+    (jnp.float8_e4m3fn, jnp.float8_e5m2),
+]
+
+
+def _use_jax_fp8_gemm(enabled=False):
+    import os
+
+    if enabled:
+        os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+    elif "NVTE_JAX_CUSTOM_CALLS_RE" in os.environ:
+        os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+
+
 class TestDense:
     def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
         if data_layout[0] == "T":

@@ -883,27 +899,47 @@ class TestDense:
     def test_gemm_bf16(self, m, n, k, data_layout):
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

-        primitive_out = tex.gemm(x, w, contracting_dims)
+        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims)
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
-    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
+    @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
-    def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm):
+        if (
+            not with_jax_gemm
+            and scaling_mode.is_1d_block_scaling()
+            and jnp.float8_e5m2 in (x_qtype, w_qtype)
+        ):
+            pytest.skip("Float8E5M2 is not recommended for MXFP8 GEMM.")
+
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
+            scaling_mode=scaling_mode,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2,
+            is_2x2x=False,
         )
-        primitive_out = tex.gemm(
-            x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set
-        )
+        with use_jax_gemm(enabled=with_jax_gemm):
+            primitive_out = tex.gemm(
+                x,
+                w,
+                contracting_dims=contracting_dims,
+                lhs_quantizer=(
+                    quantizer_set.x if x_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
+                rhs_quantizer=(
+                    quantizer_set.kernel if w_qtype == jnp.float8_e4m3fn else quantizer_set.dgrad
+                ),
+            )
         ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

-        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)

     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     def test_dense_grad_bf16(self, m, n, k):

@@ -932,9 +968,9 @@ class TestDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
-    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm):
         data_layout = "NN"
         x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

@@ -956,10 +992,14 @@ class TestDense:
         value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2))

         quantizer_set = QuantizerFactory.create_set(
-            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
+            scaling_mode=scaling_mode,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
+            is_2x2x=True,
         )

         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                     value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
                 )

@@ -969,10 +1009,10 @@ class TestDense:
             x, w, bias, data_layout
         )

-        assert_allclose(primitive_out, ref_out, dtype=q_dtype)
-        assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
-        assert_allclose(primitive_w_grad, ref_w_grad, dtype=q_dtype)
-        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=q_dtype)
+        assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn)
+        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2)

     @pytest.fixture(name="random_inputs")

@@ -996,20 +1036,13 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 class TestFusedDense:
     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
-    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    def test_layernorm_dense_grad(self, m, n, k, q_dtype, scaling_mode, norm_type):
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm):
         """
         Test layernorm_dense VJP Rule
         """
-        # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
-            ScalingMode.DELAYED_TENSOR_SCALING,
-            ScalingMode.CURRENT_TENSOR_SCALING,
-        ):
-            pytest.skip("E5M2 is not supported in normalization with TE Backend!")
-
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6

@@ -1025,8 +1058,8 @@ class TestFusedDense:
         quantizer_set = QuantizerFactory.create_set(
             scaling_mode=scaling_mode,
-            fwd_dtype=q_dtype,
-            bwd_dtype=q_dtype,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
             is_2x2x=True,
         )

@@ -1064,6 +1097,7 @@ class TestFusedDense:
         )

         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 prim_out, (
                     prim_x_grad,

@@ -1072,33 +1106,26 @@ class TestFusedDense:
                     prim_beta_grad,
                 ) = value_n_grad_prim_func(x, w, gamma, beta)

-        assert_allclose(prim_out, ref_out, dtype=q_dtype)
-        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
-        assert_allclose(prim_w_grad, ref_w_grad, dtype=q_dtype)
-        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
+        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
+        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2)
+        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
         if beta is not None:
-            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
+            assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2)

     @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
     @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
-    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
     @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
-    @pytest.mark.parametrize("use_bias", [True, False])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
-        self, m, n, k, activation_type, q_dtype, scaling_mode, norm_type, use_bias
+        self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm
     ):
         """
         Test layernorm_mlp VJP Rule
         """
-        # No Norm FWD E5M2 in TE backend
-        if q_dtype == jnp.float8_e5m2 and scaling_mode in (
-            ScalingMode.DELAYED_TENSOR_SCALING,
-            ScalingMode.CURRENT_TENSOR_SCALING,
-        ):
-            pytest.skip("E5M2 is not supported in normalization with TE Backend!")
-
         # zero_centered_gamma is already tested in TestNorm
         zero_centered_gamma = False
         eps = 1e-6

@@ -1123,8 +1150,8 @@ class TestFusedDense:
         quantizer_sets = QuantizerFactory.create_set(
             n_quantizer_sets=2,
             scaling_mode=scaling_mode,
-            fwd_dtype=q_dtype,
-            bwd_dtype=q_dtype,
+            fwd_dtype=jnp.float8_e4m3fn,
+            bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn,
             is_2x2x=True,
         )

@@ -1153,14 +1180,13 @@ class TestFusedDense:
             ln_out = _ref_jax_norm_impl(
                 x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer=None
             )
-            # TODO: replace gemm with jnp.dot
-            linear_1_out = tex.gemm(ln_out, kernel_1, ((1,), (0,)))
+            linear_1_out = jax.lax.dot_general(ln_out, kernel_1, (((1,), (0,)), ((), ())))
             if use_bias:
                 bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape
                 linear_1_out += jnp.reshape(bias_1, bias_1_shape)

             x = _jax_act_lu(linear_1_out, activation_type)
-            linear_2_out = tex.gemm(x, kernel_2, ((1,), (0,)))
+            linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ())))
             if use_bias:
                 bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape
                 linear_2_out += jnp.reshape(bias_2, bias_2_shape)

@@ -1174,6 +1200,7 @@ class TestFusedDense:
         value_n_grad_ref_func = value_and_grad(ref_func, range(6))

         n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
+        with use_jax_gemm(enabled=with_jax_gemm):
             for _ in range(n_iterations):
                 prim_out, (
                     prim_x_grad,

@@ -1193,18 +1220,18 @@ class TestFusedDense:
                     ref_bias_2_grad,
                 ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2)

-        assert_allclose(prim_out, ref_out, dtype=q_dtype)
-        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=q_dtype)
+        assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn)
+        assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2)
         if use_bias:
-            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=q_dtype)
+            assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2)
-        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=q_dtype)
+        assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2)
         if use_bias:
-            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=q_dtype)
+            assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2)
-        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=q_dtype)
+        assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2)
-        assert_allclose(prim_x_grad, ref_x_grad, dtype=q_dtype)
+        assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2)

     # E5M2 * E5M2 is not supported

@@ -1238,7 +1265,9 @@ class TestGroupedDense:
         ref_out = []
         dim_num = (contracting_dims, ((), ()))
         for lhs_i, rhs_i, bias_i in zip(lhs, rhs, bias):
-            out_i = jax.lax.dot_general(lhs_i, rhs_i, dim_num) + jnp.expand_dims(bias_i, axis=0)
+            out_i = jax.lax.dot_general(
+                lhs_i, rhs_i, dim_num, precision=jax.lax.Precision.HIGHEST
+            ) + jnp.expand_dims(bias_i, axis=0)
             ref_out.append(jnp.squeeze(out_i))
         return ref_out

@@ -1250,6 +1279,9 @@ class TestGroupedDense:
         group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m))
         group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])])
         group_sizes = jnp.diff(group_sizes)
+        # Make one empty input lhs to test empty GEMM handling
+        group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1])
+        group_sizes = group_sizes.at[1].set(0)
         assert group_sizes.sum() == m

         # *32 to make sure that input shape works for MXFP8

@@ -1301,9 +1333,6 @@ class TestGroupedDense:
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("layout", ["NN"])
     def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_gemm yet")
-
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         quantizer_set = QuantizerFactory.create_set(
             scaling_mode=scaling_mode,

@@ -1343,9 +1372,10 @@ class TestGroupedDense:
     def _ref_sum_grouped_dense(self, x, kernel, bias, group_sizes, contracting_dims):
         out_list = self._ref_grouped_dense(x, kernel, bias, group_sizes, contracting_dims)
         # Note: we use jnp.sum instead of jnp.mean to make the gradient larger
-        # and prevent them from being clamp to zero
+        # and prevent them from being clamp to zero in FP8. / sqrt(x.size) is used to
+        # normalize the output and prevent the gradient from being too large for FP8.
         out_sum_list = [jnp.sum(out) for out in out_list]
-        return jnp.sum(jnp.asarray(out_sum_list))
+        return jnp.sum(jnp.asarray(out_sum_list)) / jnp.sqrt(x.size)

     def _primitive_sum_grouped_dense(
         self, x, kernel, bias, group_sizes, contracting_dims, quantizer_set=noop_quantizer_set

@@ -1353,7 +1383,7 @@ class TestGroupedDense:
         out = grouped_dense(
             x, kernel, group_sizes, contracting_dims, bias=bias, quantizer_set=quantizer_set
         )
-        return jnp.sum(jnp.asarray(out))
+        return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size)

     @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16])
     def test_grouped_dense_grad_fp16(self, dtype, input_shape):

@@ -1388,9 +1418,6 @@ class TestGroupedDense:
     )
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape):
-        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
-            pytest.skip("MXFP8 is not supported in grouped_dense yet")
-
         fwd_dtype, bwd_dtype = fwd_bwd_dtype
         dtype = jnp.bfloat16
         x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input(
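Many of the hunks above drop the single q_dtype parameter in favour of a fixed convention: float8_e4m3fn for forward tensors and float8_e5m2 for backward (gradient) tensors under tensor scaling. The snippet below is only an illustration of the trade-off behind that convention; the describe helper is hypothetical and the finfo fields assume a recent JAX/ml_dtypes.

import jax.numpy as jnp

def describe(dtype):
    # Largest representable magnitude and number of mantissa bits for an FP8 format.
    info = jnp.finfo(dtype)
    return float(info.max), int(info.nmant)

print("e4m3:", describe(jnp.float8_e4m3fn))  # ~448 max, 3 mantissa bits: better precision for fwd tensors
print("e5m2:", describe(jnp.float8_e5m2))    # ~57344 max, 2 mantissa bits: wider range for gradients

This is also why the reference losses above are rescaled by 1 / sqrt(x.size): keeping gradients away from both the underflow and overflow ends of the narrower FP8 ranges.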
tests/jax/test_distributed_layernorm.py (+0 -2)

@@ -75,8 +75,6 @@ class TestDistributedLayernorm:
             all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
         )
         other_bytes = 0
-        if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
-            other_bytes = 384  # required for small scale shapes that require padding
         if fp8_recipe == recipe.Float8CurrentScaling():
             allreduce_total_bytes += jax_dtype.itemsize  # 1 * dtype for the amax reduction
         return generate_collectives_count(
tests/jax/test_distributed_layernorm_mlp.py (+228 -105)

@@ -13,6 +13,7 @@ from utils import (
     assert_tree_like_allclose,
     is_devices_enough,
     pytest_parametrize_wrapper,
+    use_jax_gemm,
 )

 from transformer_engine.common import recipe

@@ -33,6 +34,7 @@ from transformer_engine.jax.sharding import (
 )
 from transformer_engine.jax.sharding import MeshResource
 from transformer_engine.jax.quantize import QuantizerFactory
+from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability

 is_fp8_supported, reason = is_fp8_available()

@@ -146,7 +148,15 @@ class TestDistributedLayernormMLP:
     )

     def _test_layernorm_mlp_grad(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
+        self,
+        mesh_config,
+        activation_type,
+        use_bias,
+        input_shape,
+        dtype,
+        fp8_recipe,
+        use_shardy,
+        with_jax_gemm,
     ):
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config

@@ -156,6 +166,8 @@ class TestDistributedLayernormMLP:
             input_shape, activation_type, use_bias, dtype
         )
         static_inputs = [layernorm_type, activation_type]
+
+        with use_jax_gemm(enabled=with_jax_gemm):
             value_and_grad_func = jax.value_and_grad(
                 self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
             )

@@ -171,7 +183,9 @@ class TestDistributedLayernormMLP:
             # Multi GPUs
             devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
             mesh = Mesh(devices, mesh_axes)
             with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
                 k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
                 k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
                 k1_ = jax.device_put(k1, k1_sharding)

@@ -203,25 +217,32 @@ class TestDistributedLayernormMLP:
                     value_and_grad_func,
                     in_shardings=in_shardings,
                     out_shardings=out_shardings,
                     static_argnums=range(
                         len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1
                     ),
                 )  # +1 for multi_gpus

                 multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)

-        assert_allclose(multi_fwd, single_fwd, dtype=dtype)
+        fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
+        bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
+        assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)

         for i in range(len(inputs)):
             if multi_grads[i] is not None:
                 if isinstance(multi_grads[i], list):
                     assert isinstance(single_grads[i], list)
                     for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
                         assert_allclose(
-                            m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
+                            m_grad,
+                            s_grad,
+                            dtype=bwd_test_type,
+                            err_msg=f"multi_grads[{i}] is not close",
                         )
                 else:
                     assert_allclose(
                         multi_grads[i],
                         single_grads[i],
-                        dtype=dtype,
+                        dtype=bwd_test_type,
                         err_msg=f"multi_grads[{i}] is not close",
                     )

@@ -232,8 +253,16 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+        self,
+        mesh_config,
+        activation_type,
+        use_bias,
+        input_shape,
+        dtype,
+        fp8_recipe,
+        with_jax_gemm,
     ):
         self._test_layernorm_mlp_grad(
             mesh_config,

@@ -243,6 +272,7 @@ class TestDistributedLayernormMLP:
             dtype,
             fp8_recipe,
             use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
         )

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)

@@ -251,19 +281,29 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_grad_shardy(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
     ):
-        # We don't test block scaling with Shardy because at the time of writing,
-        # it is not supported in JAX's scaled_matmul_stablehlo.
+        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
         self._test_layernorm_mlp_grad(
             mesh_config,
             activation_type,
             use_bias,
             input_shape,
             dtype,
-            fp8_recipe=recipe.DelayedScaling(),
+            fp8_recipe=fp8_recipe,
             use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
         )

     def _test_layernorm_mlp(

@@ -276,6 +316,7 @@ class TestDistributedLayernormMLP:
         use_fp8,
         fp8_recipe,
         use_shardy,
+        with_jax_gemm,
     ):
         jax.config.update("jax_use_shardy_partitioner", use_shardy)
         batch, seqlen, hidden_in = input_shape

@@ -287,6 +328,7 @@ class TestDistributedLayernormMLP:
         x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
         init_rngs = {"params": subkeys[1]}

+        with use_jax_gemm(enabled=with_jax_gemm):
             # Single GPUs
             with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
                 ln_mlp_single = LayerNormMLP(

@@ -333,16 +375,48 @@ class TestDistributedLayernormMLP:
         # Make sure params values are the same
         assert_tree_like_allclose(params_sharded["params"], params_single["params"])
         assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
-        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
+
+        atol = None
+        rtol = None
+        l40_tolerance_update = (
+            get_min_device_compute_capability() == 89
+            and fp8_recipe == recipe.DelayedScaling()
+            and use_fp8
+            and dtype == jnp.float16
+            and activation_type == ("gelu",)
+        )
+        if l40_tolerance_update:
+            atol = 0.04
+            rtol = 11
+
+        # JAX's FP8 GEMM, jax.lax.dot_general, now uses the
+        # Triton backend by default. The error of
+        # the Triton FP8 gemm has been verified to be less than or equal
+        # to the error of the cuDNN FP8 gemm w.r.t a float32 ground truth.
+        # However, Triton can auto-tune a different kernel for the single GPU
+        # and multi-GPU run in this test, meaning the diff between single GPU
+        # and multi-GPU can be larger in some cases, even though both are
+        # within tolerance to the float32 ground truth.
+        jax_triton_gemm_precision_tolerance_update = (
+            with_jax_gemm
+            and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
+            and dtype == jnp.bfloat16
+            and activation_type == ("gelu", "linear")
+        )
+        if jax_triton_gemm_precision_tolerance_update:
+            atol = 0.08
+            rtol = 15
+
+        assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol)

     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
     @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("use_bias", [True, False])
-    @pytest_parametrize_wrapper("use_shardy", [False, True])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_layer(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, use_shardy
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
     ):
         self._test_layernorm_mlp(
             mesh_config,

@@ -352,7 +426,8 @@ class TestDistributedLayernormMLP:
             dtype,
             use_fp8=False,
             fp8_recipe=None,
-            use_shardy=use_shardy,
+            use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
         )

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)

@@ -362,8 +437,9 @@ class TestDistributedLayernormMLP:
     @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
     @pytest_parametrize_wrapper("dtype", DTYPES)
     @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
     def test_layernorm_mlp_layer_fp8(
-        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
     ):
         self._test_layernorm_mlp(
             mesh_config,

@@ -374,4 +450,51 @@ class TestDistributedLayernormMLP:
             use_fp8=True,
             fp8_recipe=fp8_recipe,
             use_shardy=False,
+            with_jax_gemm=with_jax_gemm,
+        )
+
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")])
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_mlp_layer_shardy(
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm
+    ):
+        self._test_layernorm_mlp(
+            mesh_config,
+            activation_type,
+            use_bias,
+            input_shape,
+            dtype,
+            use_fp8=False,
+            fp8_recipe=None,
+            use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
+        )
+
+    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+    @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+    @pytest_parametrize_wrapper("use_bias", [True, False])
+    @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+    @pytest_parametrize_wrapper("dtype", DTYPES)
+    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+    @pytest_parametrize_wrapper("with_jax_gemm", [False, True])
+    def test_layernorm_mlp_layer_fp8_shardy(
+        self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm
+    ):
+        if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
+            pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.")
+        self._test_layernorm_mlp(
+            mesh_config,
+            activation_type,
+            use_bias,
+            input_shape,
+            dtype,
+            use_fp8=True,
+            fp8_recipe=fp8_recipe,
+            use_shardy=True,
+            with_jax_gemm=with_jax_gemm,
         )
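The tolerance-related comment in the last hunk argues that the single-GPU and multi-GPU results can drift apart while each still stays close to a float32 ground truth, which is why atol/rtol are widened rather than comparing the two runs bit-for-bit. A hedged sketch of that kind of check, with illustrative names that are not part of this test file:

import jax.numpy as jnp

def max_rel_err(x, ref):
    # Relative error of a low-precision result against a float32 ground truth.
    ref32 = ref.astype(jnp.float32)
    return float(jnp.max(jnp.abs(x.astype(jnp.float32) - ref32) / (jnp.abs(ref32) + 1e-6)))

# single_out, multi_out, and ref_fp32 are assumed to come from the comparison under test:
# err_single = max_rel_err(single_out, ref_fp32)
# err_multi = max_rel_err(multi_out, ref_fp32)
# Both errors can be small while |single_out - multi_out| is comparatively large,
# e.g. when Triton auto-tunes different kernels for the two runs.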
tests/jax/test_helper.py (+2 -2)

@@ -92,7 +92,7 @@ class TestFP8Functions(unittest.TestCase):
         self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
-    def test_fp8_autocast_mxfp8_scaling(self):
+    def test_fp8_autocast_current_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

@@ -116,7 +116,7 @@ class TestFP8Functions(unittest.TestCase):
         self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
-    def test_fp8_autocast_mxfp8_scaling(self):
+    def test_fp8_autocast_mxfp8_block_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
tests/jax/utils.py (+23 -6)

@@ -3,11 +3,12 @@
 # See LICENSE for license information.
 """Utility for the TE layer tests"""
-import os
 import functools
 import math
 import operator
-from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional
+from typing import Any, Callable, Dict, Tuple, Sequence, Union, Iterable, Optional, NewType
+import os
+from contextlib import contextmanager

 import jax
 import jax.numpy as jnp

@@ -20,7 +21,6 @@ from jax import random as jax_random
 import pytest

 from transformer_engine.jax.attention import (
-    AttnMaskType,
     canonicalize_attn_mask_type,
     make_swa_mask,
 )

@@ -28,8 +28,8 @@ from transformer_engine.jax.quantize.helper import DType as TEDType
 PRNGKey = Any
 Shape = Tuple[int, ...]
-DType = jnp.dtype
-Array = Any
+DType = NewType("DType", jnp.dtype)
+Array = NewType("Array", jnp.ndarray)
 PrecisionLike = Union[
     None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
 ]

@@ -1519,7 +1519,7 @@ def dtype_tols(
             TEDType.kFloat8E5M2: jnp.float8_e5m2,
         }[dtype]
     elif isinstance(dtype, np.dtype):
-        dtype = jnp.dtype(dtype)
+        dtype = DType(dtype)

     # Expect bit-wise accuracy for integer dtypes
     if not jnp.issubdtype(dtype, jnp.floating):

@@ -1600,3 +1600,20 @@ def print_debug_tensor_stats(prefix, tensor, hist=False):
         fmt = fmt + "\n    {}\n    {}"

     jax.debug.print(fmt, *args)
+
+
+@contextmanager
+def use_jax_gemm(enabled=False):
+    orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None)
+    try:
+        if enabled:
+            os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$"
+        yield
+    finally:
+        if enabled:
+            if orig_custom_calls_filter is None:
+                os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE")
+            else:
+                os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter
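The use_jax_gemm context manager added here toggles the GEMM backend by setting NVTE_JAX_CUSTOM_CALLS_RE to a regex that matches every TE custom call except GemmPrimitive. A minimal usage sketch, assuming these utilities are importable as utils; run_both_paths is a hypothetical helper, not code from this change:

from utils import use_jax_gemm

def run_both_paths(fn, *args):
    # With the filter active, TE's JAX extensions skip GemmPrimitive and fall back
    # to the native JAX/XLA GEMM; without it, the TE custom-call GEMM is used.
    with use_jax_gemm(enabled=True):
        jax_out = fn(*args)
    te_out = fn(*args)
    return jax_out, te_out

Restoring the previous value of the environment variable in the finally block keeps nested or repeated uses of the context manager from clobbering a filter that a caller may have set externally.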
tests/pytorch/debug/run_distributed.py (+46 -27)

@@ -16,7 +16,7 @@ import transformer_engine
 import transformer_engine_torch as tex
 import nvdlfw_inspect.api as debug_api
 from transformer_engine.debug import set_weight_tensor_tp_group_reduce
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from test_numerics import (
     _emulate_linear,

@@ -45,6 +45,8 @@ FEATURE_DIRS = None
 all_boolean = [True, False]
 TEST_NR = 0

+fp8_available, _ = FP8GlobalStateManager.is_fp8_available()
+

 def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
     if tp_size is None:

@@ -221,7 +223,7 @@ def run_debug_test(func):
     return wrapper


-CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
+CONFIG_LOG_TEST_DISTRIBUTED_FP8 = """log_distributed:
   layers:
     layer_types: [linear]
   enabled:

@@ -241,11 +243,27 @@ CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
       end_step: 1
 """

+CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8 = """log_distributed:
+  layers:
+    layer_types: [linear]
+  enabled:
+    True
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      tensors: [activation, gradient, weight, output, wgrad, dgrad]
+      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
+      start_step : 0
+      end_step: 1
+"""
+

 def _prepare_config_test_log_distributed(config_file):
     if WORLD_RANK != 0:
         return
-    config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
+    config_file.write(
+        CONFIG_LOG_TEST_DISTRIBUTED_FP8 if fp8_available else CONFIG_LOG_TEST_DISTRIBUTED_NO_FP8
+    )
     config_file.flush()

@@ -361,13 +379,13 @@ def test_log_expert_parallel(**kwargs):
     )

     # data parallel
     model = _init_model(weight, parallel_mode=None, name="linear1")
     model1 = _init_model(weight, parallel_mode=None, name="linear2")

-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y1 = model(x)
         y2 = model1(x)
         y = y1 + y2
     y.sum().backward()
     debug_api.step()

-    with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+    with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE):
         y = model(x)
         if WORLD_RANK != 0:
             y = y + model1(x)

@@ -620,6 +638,7 @@ if __name__ == "__main__":
         for gather_weight in [True, False]:
             test_log_distributed(parallel_mode, gather_weight)

+    if fp8_available:
         for parallel_mode in ["row", "column"]:
             test_disable_fp8_layer(parallel_mode)
tests/pytorch/debug/test_distributed.py
@@ -5,7 +5,6 @@
 import os
 import subprocess
 from pathlib import Path
 import pytest
 import torch
@@ -21,7 +20,6 @@ import torch
     """
     if torch.cuda.device_count() < 2:
         pytest.skip("Distributed training needs at least 2 GPUs.")
@@ -34,6 +32,6 @@ def test_debug_distributed(feature_dirs):
     test_path = TEST_ROOT / "run_distributed.py"
     test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
-    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
+    result = subprocess.run(test_cmd, env=os.environ, check=False, text=True)
     if result.returncode != 0:
-        raise AssertionError(result.stderr.decode())
+        raise AssertionError(f"torchrun exited with {result.returncode}")
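Because the new call passes text=True without capture_output, the child's stdout and stderr stream straight to the test log and only the exit code is available for the assertion message. A variant that still embeds the captured stderr in the failure (only a sketch, not what the test does) would be:

    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, text=True)
    if result.returncode != 0:
        # With text=True, result.stderr is already a str, so no .decode() is needed.
        raise AssertionError(f"torchrun exited with {result.returncode}:\n{result.stderr}")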
tests/pytorch/debug/test_numerics.py
@@ -27,6 +27,9 @@ from transformer_engine.pytorch.module.base import (
     _2X_ACC_FPROP,
     _2X_ACC_WGRAD,
 )
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 all_boolean = [True, False]
 FP8_FORMAT = Format.HYBRID
@@ -246,8 +249,8 @@ def _init_model(weight):
     return model
-def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None):
-    with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
+def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True):
+    with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE):
         y = model(x, is_first_microbatch=is_first_microbatch)
         (y.sum() * loss_scale).backward()
     debug_api.step()
@@ -262,6 +265,18 @@ def _get_tensors():
     return x, weight
+LOGGING_CONFIG = """logging_config:
+  enabled: True
+  layers:
+    layer_types: [linear]
+  transformer_engine:
+    LogTensorStats:
+      enabled: True
+      tensors: [activation, gradient, weight, output, wgrad, dgrad]
+      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
+"""
+
 DISABLE_FP8_CONFIG = Template(
     """disable_fp8_config:
   enabled: True
@@ -275,10 +290,30 @@ DISABLE_FP8_CONFIG = Template(
 )
+@create_config_file
+def run_logging_zero_numel_tensor(feature_dirs, **kwargs):
+    kwargs["config_file"].write(LOGGING_CONFIG)
+    kwargs["config_file"].flush()
+    _init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
+    x, weight = _get_tensors()
+    x1 = x[:0, :]
+    model = _init_model(weight)
+    _ = _run_forward_backward(x1, model, fp8=False)
+    _ = _run_forward_backward(x, model, fp8=False)
+
+
+def test_logging_zero_numel_tensor(feature_dirs):
+    run_logging_zero_numel_tensor(feature_dirs)
+
+
 @pytest.mark.parametrize("fprop_fp8", all_boolean)
 @pytest.mark.parametrize("dgrad_fp8", all_boolean)
 @pytest.mark.parametrize("wgrad_fp8", all_boolean)
 def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)
@@ -318,6 +353,8 @@ def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwarg
 def test_disable_fp8_layer(feature_dirs):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_disable_fp8_layer(feature_dirs)
@@ -363,6 +400,8 @@ subset_combinations = random.sample(all_combinations, 20)
 def test_per_tensor_scaling(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
         pytest.skip("Skipping test because all parameters are False")
     run_per_tensor_scaling(
@@ -535,6 +574,8 @@ def run_per_tensor_scaling(
 def test_microbatching_per_tensor_scaling(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
         pytest.skip("Skipping test because all parameters are False")
@@ -624,6 +665,8 @@ subset_combinations = random.sample(all_combinations, 10)
 def test_fake_quant_fp8(
     feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
 ):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     run_fake_quant_fp8(
         feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
     )
tests/pytorch/debug/test_sanity.py
@@ -2,27 +2,17 @@
 #
 # See LICENSE for license information.
-import functools
-import itertools
-import os
-import random
-import tempfile
-from string import Template
 import pytest
 import torch
 import nvdlfw_inspect.api as debug_api
-import transformer_engine.debug
 import transformer_engine.pytorch as te
-import transformer_engine_torch as tex
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.common.recipe import DelayedScaling, Format
-from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.fp8 import _default_sf_compute
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from test_numerics import create_config_file
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
 B, S, H, D = 64, 64, 64, 64
 model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]
@@ -104,4 +94,6 @@ def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
 @pytest.mark.parametrize("fp8", [False, True])
 @pytest.mark.parametrize("config_key", configs.keys())
 def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     _run_test(model_key, fp8, configs[config_key], feature_dirs)
tests/pytorch/distributed/run_numerics.py
@@ -48,11 +48,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
     )
-# Disable TF32
-torch.backends.cuda.matmul.allow_tf32 = False
-torch.backends.cudnn.allow_tf32 = False
 # Quantization recipe setup
 def quantization_recipe() -> Recipe:
     if QUANTIZATION == "fp8":
@@ -167,7 +162,7 @@ def _gather(tensor, dim=0):
 def _constant(tensor):
-    return nn.init.constant_(tensor, 0.5)
+    return nn.init.constant_(tensor, 0.05)
 def dist_print(msg, src=None, end="\n", error=False):
@@ -190,7 +185,8 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.2e-4, "atol": 1e-4}
+        # TF32 has same mantissa bits as FP16
+        return {"rtol": 1e-3, "atol": 1e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")
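The relaxed float32 tolerance goes hand in hand with no longer disabling TF32 above: TF32 keeps 10 explicit mantissa bits, the same as FP16, so a single rounding step is on the order of 2**-10, roughly 9.8e-4, which the test rounds to rtol=1e-3. A one-line sanity check of that bound:

    assert 2 ** -10 == 0.0009765625  # ~1e-3, hence the float32 rtol used under TF32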
@@ -521,8 +517,11 @@ def test_linear():
         {"return_bias": True},
         {"params_dtype": torch.float16},
         {"delay_wgrad_compute": True},
+        {"save_original_input": True},
     ]
     for kwargs in kwargs_list:
+        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
+            continue
         for parallel_mode in ["column", "row"]:
             for sequence_parallel in [False, True]:
                 _test_linear(parallel_mode, sequence_parallel, **kwargs)
tests/pytorch/distributed/test_fusible_ops.py
@@ -28,7 +28,6 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
@@ -21,7 +21,6 @@ import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.ops.fused import (
     UserbuffersBackwardLinear,
     UserbuffersForwardLinear,
@@ -32,6 +31,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 # Import utility functions
@@ -370,7 +370,7 @@ def _test_linear(
     if quantized_compute:
         tols = dtype_tols(
             model[0].weight._fp8_dtype
-            if is_float8_tensor(model[0].weight)
+            if isinstance(model[0].weight, Float8Tensor)
             else tex.DType.kFloat8E4M3
         )
tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -89,7 +89,7 @@ def run_dpa_with_cp(
     # instantiate core attn module
     core_attn = DotProductAttention(
         config.num_heads,
-        config.head_dim_qk,
+        (config.head_dim_qk, config.head_dim_v),
         num_gqa_groups=config.num_gqa_groups,
         attention_dropout=config.dropout_p,
         qkv_format=qkv_format,
@@ -106,16 +106,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
             config.batch_size,
             config.max_seqlen_kv,
             config.num_gqa_groups,
             config.head_dim_qk,
         )
+        v_input_shape = (
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size,
             config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -128,16 +134,22 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
            config.max_seqlen_kv,
            config.batch_size,
            config.num_gqa_groups,
            config.head_dim_qk,
        )
+        v_input_shape = (
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.max_seqlen_q,
             config.batch_size,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         cu_seqlens_q = None
         cu_seqlens_kv = None
@@ -149,14 +161,19 @@ def run_dpa_with_cp(
             config.num_heads,
             config.head_dim_qk,
         )
-        kv_input_shape = (
+        k_input_shape = (
            config.batch_size * config.max_seqlen_q,
            config.num_gqa_groups,
            config.head_dim_qk,
        )
+        v_input_shape = (
+            config.batch_size * config.max_seqlen_q,
+            config.num_gqa_groups,
+            config.head_dim_v,
+        )
         attn_output_shape = (
             config.batch_size * config.max_seqlen_q,
-            config.num_heads * config.head_dim_qk,
+            config.num_heads * config.head_dim_v,
         )
         seqlens_q = torch.randint(0, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
         seqlens_q_padded = (seqlens_q + 2 * world_size - 1) // (world_size * 2) * (world_size * 2)
@@ -177,8 +194,8 @@ def run_dpa_with_cp(
         assert False, f"{qkv_format} is an unsupported qkv_format!"
     q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
-    k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
-    v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
+    k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda()
+    v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda()
     dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()
     dout_quantizer = Float8Quantizer(
         fp8_dtype=tex.DType.kFloat8E5M2,
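Splitting the old kv_input_shape into k_input_shape and v_input_shape is what lets the context-parallel test cover MLA-style attention, where K keeps head_dim_qk while V, and therefore the attention output, uses head_dim_v. The shared structure is easy to factor out; a small helper like the following (hypothetical, not part of the test) captures the bshd case:

    def mla_kv_shapes_bshd(config):
        # K and V share batch/sequence/group dimensions and differ only in head size.
        base = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups)
        return base + (config.head_dim_qk,), base + (config.head_dim_v,)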
tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -174,6 +174,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("Only fp8 works with fp8_mha=True!")
     if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v:
         pytest.skip("MLA CP currently only support KV P2P!")
+    if dtype == "fp8" and config.head_dim_qk != config.head_dim_v:
+        pytest.skip("MLA CP currently does not support FP8 attention!")
     subprocess.run(
         get_bash_arguments(
tests/pytorch/test_checkpoint.py
(new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from __future__ import annotations
import argparse
import functools
import os
import pathlib

import pytest
import torch

import transformer_engine.pytorch as te
from utils import make_recipe

# Check supported quantization schemes
fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available()

# Test cases for loading checkpoint files
_TestLoadCheckpoint_name_list: tuple[str, ...] = (
    "linear",
    "layernorm_linear",
    "layernorm_mlp",
    "layernorm",
    "rmsnorm",
    "transformer_layer",
    "ops_linear",
    "linear.fp8",
    "ops_linear.fp8",
    "linear.mxfp8",
    "ops_linear.mxfp8",
)


class TestLoadCheckpoint:
    """Tests for loading checkpoint files

    Tests assume that checkpoint files have already been created. In
    order to regenerate checkpoint files, e.g. after a breaking change
    in the checkpoint format, run this file directly as a Python
    script: `python3 test_checkpoint.py --save-checkpoint all`.
    """

    @staticmethod
    def _make_module(name: str) -> torch.nn.Module:
        """Construct a module"""
        if name == "linear":
            return te.Linear(1, 1)
        if name == "layernorm_linear":
            return te.LayerNormLinear(1, 1)
        if name == "layernorm_mlp":
            return te.LayerNormMLP(1, 1)
        if name == "layernorm":
            return te.LayerNorm(1)
        if name == "rmsnorm":
            return te.RMSNorm(1)
        if name == "transformer_layer":
            return te.TransformerLayer(1, 1, 1)
        if name == "ops_linear":
            return te.ops.Linear(1, 1)
        if name == "linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.Linear(16, 16)
        if name == "ops_linear.fp8":
            with te.fp8_model_init(recipe=make_recipe("fp8")):
                return te.ops.Linear(16, 16)
        if name == "linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.Linear(32, 32)
        if name == "ops_linear.mxfp8":
            with te.fp8_model_init(recipe=make_recipe("mxfp8")):
                return te.ops.Linear(32, 32)
        raise ValueError(f"Unrecognized module name ({name})")

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _checkpoint_dir() -> pathlib.Path:
        """Path to directory with checkpoint files"""

        # Check environment variable
        path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH")
        if path:
            return pathlib.Path(path).resolve()

        # Fallback to path in root dir
        root_dir = pathlib.Path(__file__).resolve().parent.parent.parent
        return root_dir / "artifacts" / "tests" / "pytorch" / "test_checkpoint"

    @staticmethod
    def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -> None:
        """Save a module's checkpoint file"""

        # Path to save checkpoint
        if checkpoint_dir is None:
            checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
        checkpoint_dir.mkdir(exist_ok=True)
        checkpoint_file = checkpoint_dir / f"{name}.pt"

        # Create module and save checkpoint
        module = TestLoadCheckpoint._make_module(name)
        torch.save(module.state_dict(), checkpoint_file)
        print(f"Saved checkpoint for {name} at {checkpoint_file}")

    @pytest.mark.parametrize("name", _TestLoadCheckpoint_name_list)
    def test_module(self, name: str) -> None:
        """Test for loading a module's checkpoint file"""

        # Skip if quantization is not supported
        quantization = None
        if "." in name:
            quantization = name.split(".")[1]
        if quantization == "fp8" and not fp8_available:
            pytest.skip(reason_for_no_fp8)
        if quantization == "mxfp8" and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)

        # Construct module
        module = self._make_module(name)

        # Load checkpoint from file
        checkpoint_file = self._checkpoint_dir() / f"{name}.pt"
        if not checkpoint_file.is_file():
            raise FileNotFoundError(f"Could not find checkpoint file at {checkpoint_file}")
        state_dict = torch.load(checkpoint_file, weights_only=False)

        # Update module from checkpoint
        module.load_state_dict(state_dict, strict=True)


def main() -> None:
    """Main function

    Typically used to generate checkpoint files.
    """

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-checkpoint",
        type=str,
        default=None,
        help="Save checkpoint file for a module",
    )
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default=None,
        help="Directory to save checkpoint file in",
    )
    args = parser.parse_args()

    # Save checkpoint files if needed
    if args.save_checkpoint is not None:
        checkpoint_dir = args.checkpoint_dir
        if checkpoint_dir is not None:
            checkpoint_dir = pathlib.Path(checkpoint_dir).resolve()
        if args.save_checkpoint == "all":
            for name in _TestLoadCheckpoint_name_list:
                TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
        else:
            TestLoadCheckpoint._save_checkpoint(
                args.save_checkpoint,
                checkpoint_dir=checkpoint_dir,
            )


if __name__ == "__main__":
    main()
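As the class docstring says, these tests consume pre-generated checkpoint artifacts. Regenerating them after a format change and pointing the tests at the new directory would look roughly like this (the /tmp path is only an example):

    python3 tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir /tmp/te_checkpoints
    NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=/tmp/te_checkpoints python3 -m pytest tests/pytorch/test_checkpoint.py

Without the environment variable, the tests fall back to artifacts/tests/pytorch/test_checkpoint under the repository root.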
tests/pytorch/test_fused_router.py
(new file)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
import math
from typing import Optional, Dict
from transformer_engine.pytorch.router import (
    fused_topk_with_score_function,
    fused_compute_score_for_moe_aux_loss,
    fused_moe_aux_loss,
)
import pytest
from copy import deepcopy

seed = 42
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)


# Pytorch-based group topk
def group_limited_topk(
    scores: torch.Tensor,
    topk: int,
    num_tokens: int,
    num_experts: int,
    num_groups: int,
    group_topk: int,
):
    group_scores = (
        scores.view(num_tokens, num_groups, -1).topk(topk // group_topk, dim=-1)[0].sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)

    # Mask the experts based on selection groups
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_groups, num_experts // num_groups)
        .reshape(num_tokens, -1)
    )

    masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf"))
    probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)

    return probs, top_indices
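group_limited_topk mirrors group-limited routing as used in some recent MoE models: each group of experts is scored by the sum of its best topk // group_topk entries, only the top group_topk groups survive, and the final top-k is taken over the surviving experts. A hand-checked toy case (values invented for illustration, relying on the helper defined above):

    scores = torch.tensor([[0.1, 0.9, 0.3, 0.4]])  # 1 token, 4 experts, 2 groups of 2
    probs, idx = group_limited_topk(
        scores, topk=2, num_tokens=1, num_experts=4, num_groups=2, group_topk=1
    )
    # Group 0 scores 0.1 + 0.9 = 1.0 vs group 1's 0.3 + 0.4 = 0.7, so both picks come
    # from group 0: idx == [[1, 0]] and probs == [[0.9, 0.1]].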
# Pytorch-based topk softmax/sigmoid
def topk_softmax_sigmoid_pytorch(
    logits: torch.Tensor,
    topk: int,
    use_pre_softmax: bool = False,
    num_groups: Optional[int] = None,
    group_topk: Optional[int] = None,
    scaling_factor: Optional[float] = None,
    score_function: str = "softmax",
    expert_bias: Optional[torch.Tensor] = None,
):
    num_tokens, num_experts = logits.shape

    def compute_topk(scores, topk, num_groups=None, group_topk=None):
        if group_topk:
            return group_limited_topk(
                scores=scores,
                topk=topk,
                num_tokens=num_tokens,
                num_experts=num_experts,
                num_groups=num_groups,
                group_topk=group_topk,
            )
        else:
            return torch.topk(scores, k=topk, dim=1)

    if score_function == "softmax":
        if use_pre_softmax:
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
            probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        else:
            scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits.float()).type_as(logits)
        if expert_bias is not None:
            scores_for_routing = scores + expert_bias
            _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
        else:
            scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
        probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    if scaling_factor:
        probs = probs * scaling_factor

    topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
    topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

    return topk_masked_gates, topk_map


# Pytorch-based compute routing scores for aux loss
def compute_scores_for_aux_loss_pytorch(
    logits: torch.Tensor, topk: int, score_function: str
) -> torch.Tensor:
    if score_function == "softmax":
        scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
    elif score_function == "sigmoid":
        scores = torch.sigmoid(logits)
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
    else:
        raise ValueError(f"Invalid score_function: {score_function}")

    _, top_indices = torch.topk(scores, k=topk, dim=1)
    routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

    return routing_map, scores


# Pytorch-based aux loss
def aux_loss_pytorch(
    probs: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    total_num_tokens: int,
    topk: int,
    num_experts: int,
    moe_aux_loss_coeff: float,
):
    aggregated_probs_per_expert = probs.sum(dim=0)
    aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * (
        num_experts * moe_aux_loss_coeff / (topk * total_num_tokens * total_num_tokens)
    )
    return aux_loss
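aux_loss_pytorch is the usual switch-style load-balancing loss: with f_e the token count routed to expert e and P_e its summed router probability, the loss is num_experts * coeff * sum_e(P_e * f_e) / (topk * T**2), where T is the total token count. One way to see the normalization is that perfectly uniform routing gives P_e = T / num_experts and f_e = topk * T / num_experts, so the loss collapses to exactly the coefficient. A quick numerical check with invented numbers:

    # Uniform-routing sanity check (hypothetical values, not taken from the test)
    E, T, topk, coeff = 4, 1000, 2, 0.01
    P_e = T / E                      # probability mass per expert under uniform routing
    f_e = topk * T / E               # tokens per expert under uniform routing
    loss = (E * coeff / (topk * T * T)) * E * (P_e * f_e)
    assert abs(loss - coeff) < 1e-9  # the loss reduces to moe_aux_loss_coeff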
def run_comparison(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    num_groups,
    group_topk,
    scaling_factor,
    score_function,
    enable_bias,
):
    # Set some parameters
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    if enable_bias and score_function == "sigmoid":
        expert_bias = torch.arange(num_experts, device="cuda") * 0.1
        expert_bias = torch.flip(expert_bias, dims=[0])
        expert_bias.requires_grad = True
    else:
        expert_bias = None

    # Clone the input tensor
    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True
    if expert_bias is not None:
        expert_bias_clone = deepcopy(expert_bias)
        expert_bias_clone.requires_grad = True
    else:
        expert_bias_clone = None

    # Run the original implementation
    # We do not support the capacity factor case
    probs, routing_map = topk_softmax_sigmoid_pytorch(
        logits=logits,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias,
    )

    # Run the fused implementation
    probs_fused, routing_map_fused = fused_topk_with_score_function(
        logits=logits_clone,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function=score_function,
        expert_bias=expert_bias_clone,
    )

    torch.testing.assert_close(probs, probs_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    # Fake the loss
    loss = torch.sum(probs)
    loss_fused = torch.sum(probs_fused)

    # Backward the loss
    loss.backward()
    loss_fused.backward()

    # Check the gradient
    torch.testing.assert_close(logits.grad, logits_clone.grad)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
@pytest.mark.parametrize("enable_bias", [True, False])
def test_topk_sigmoid(
    dtype,
    num_tokens,
    num_experts,
    topk,
    group_topk,
    scaling_factor,
    enable_bias,
):
    num_groups = 8 if group_topk else None
    run_comparison(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=False,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="sigmoid",
        enable_bias=enable_bias,
    )


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("use_pre_softmax", [True, False])
@pytest.mark.parametrize("group_topk", [None, 4])
@pytest.mark.parametrize("scaling_factor", [None, 1.2])
def test_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    use_pre_softmax,
    group_topk,
    scaling_factor,
):
    num_groups = 8 if group_topk else None
    run_comparison(
        dtype=dtype,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        use_pre_softmax=use_pre_softmax,
        num_groups=num_groups,
        group_topk=group_topk,
        scaling_factor=scaling_factor,
        score_function="softmax",
        enable_bias=False,
    )


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
    if score_function == "sigmoid":
        # Construct the special logits to avoid inf in the sigmoid function
        offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
        logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
        logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    else:
        logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4
        logits = logits.view(num_tokens, num_experts)
    logits.requires_grad = True

    logits_clone = deepcopy(logits)
    logits_clone.requires_grad = True

    routing_map, scores = compute_scores_for_aux_loss_pytorch(
        logits=logits,
        topk=topk,
        score_function=score_function,
    )

    routing_map_fused, scores_fused = fused_compute_score_for_moe_aux_loss(
        logits=logits_clone,
        topk=topk,
        score_function=score_function,
    )

    torch.testing.assert_close(scores, scores_fused)
    torch.testing.assert_close(routing_map, routing_map_fused)

    loss = torch.sum(scores)
    loss.backward()
    loss_fused = torch.sum(scores_fused)
    loss_fused.backward()
    torch.testing.assert_close(logits.grad, logits_clone.grad)


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4])
def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk):
    # Construct the special probs to avoid inf in the sigmoid function
    offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4
    probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2
    probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1)
    probs = probs.view(num_tokens, num_experts)
    probs.requires_grad = True

    tokens_per_expert = torch.randint(1, 1000, (num_experts,), device="cuda", dtype=torch.int32)
    coeff = 0.01

    probs_clone = deepcopy(probs)
    probs_clone.requires_grad = True

    aux_loss = aux_loss_pytorch(
        probs=probs,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        topk=topk,
        num_experts=num_experts,
        moe_aux_loss_coeff=coeff,
    )

    aux_loss_fused = fused_moe_aux_loss(
        probs=probs_clone,
        tokens_per_expert=tokens_per_expert,
        total_num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        coeff=coeff,
    )

    torch.testing.assert_close(aux_loss, aux_loss_fused)

    # Backward
    aux_loss.backward()
    aux_loss_fused.backward()
    torch.testing.assert_close(probs.grad, probs_clone.grad)


def profile_topk_softmax(
    dtype,
    num_tokens,
    num_experts,
    topk,
    enable_bias,
    use_pre_softmax,
):
    group_topk = 4
    scaling_factor = 1.2
    test_topk_sigmoid(
        torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
    )
    test_topk_softmax(
        torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
    )


if __name__ == "__main__":
    test_fused_scores_for_aux_loss(
        dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax"
    )
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4)
    test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4)
tests/pytorch/test_fusible_ops.py
@@ -20,8 +20,8 @@ import transformer_engine.common.recipe
 import transformer_engine.pytorch as te
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 import transformer_engine.pytorch.ops as te_ops
-from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.ops.fused import (
+    BackwardBiasActivation,
     BackwardLinearAdd,
     ForwardLinearBiasActivation,
     ForwardLinearBiasAdd,
@@ -162,7 +162,7 @@ def make_reference_and_test_tensors(
     return ref, test
-class TestSequential:
+class TestSequentialContainer:
     """Tests for sequential container"""
     def test_modules(self) -> None:
@@ -1878,6 +1878,98 @@ class TestFusedOps:
         db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu")
         torch.testing.assert_close(db_test, b_ref.grad, **tols)
+
+    @pytest.mark.parametrize("activation", ("relu", "gelu"))
+    @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32)))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_backward_bias_activation(
+        self,
+        *,
+        activation: str,
+        out_shape: Iterable[int],
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+    ) -> None:
+        """Backward dbias + dact + quantize"""
+
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        hidden_size = in_shape[-1]
+
+        # Skip invalid configurations
+        with_quantization = quantization is not None
+        maybe_skip_quantization(quantization, device=device)
+        if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0):
+            pytest.skip("Unsupported tensor size for MXFP8")
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b_ref, b_test = make_reference_and_test_tensors(
+            hidden_size,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        y_ref = x_ref + b_ref.reshape([1] * (len(in_shape) - 1) + [hidden_size])
+        if activation == "gelu":
+            y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+        elif activation == "relu":
+            y_ref = torch.nn.functional.relu(y_ref)
+        else:
+            raise ValueError(f"Unexpected activation function ({activation})")
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        act_type = te_ops.GELU if activation == "gelu" else te_ops.ReLU
+        model = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=True),
+            te_ops.Bias(hidden_size, device=device, dtype=dtype),
+            act_type(),
+        )
+        with torch.no_grad():
+            model[1].bias.copy_(b_test)
+            del b_test
+        with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe):
+            y_test = model(x_test)
+        y_test.backward(dy_test)
+
+        # Check that backward operations have been fused
+        backward_ops = model._module_groups[0]._backward_ops
+        if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]:
+            assert len(backward_ops) == 2
+            assert isinstance(backward_ops[0][0], BackwardBiasActivation)
+            assert isinstance(backward_ops[1][0], te_ops.Quantize)
+        else:
+            assert len(backward_ops) == 3
+            assert isinstance(backward_ops[0][0], act_type)
+            assert isinstance(backward_ops[1][0], te_ops.Bias)
+            assert isinstance(backward_ops[2][0], te_ops.Quantize)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if with_quantization:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+        torch.testing.assert_close(db_test, b_ref.grad, **tols)
+
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
     def test_backward_linear_add(
@@ -2093,3 +2185,109 @@ class TestCheckpointing:
         torch.testing.assert_close(y_load, y_save, **tols)
         for x_load, x_save in zip(xs_load, xs_save):
             torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
+
+
+class TestSequentialModules:
+    """Test for larger Sequentials with modules commonly used together"""
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("bias", (False, True))
+    @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm"))
+    @pytest.mark.parametrize("quantized_compute", (False, True))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    def test_layernorm_mlp(
+        self,
+        *,
+        bias: bool,
+        normalization: str,
+        quantized_compute: bool,
+        quantized_weight: bool,
+        dtype: torch.dtype,
+        quantization: Optional[str],
+        device: torch.device = "cuda",
+        hidden_size: int = 32,
+        sequence_length: int = 512,
+        batch_size: int = 4,
+        ffn_hidden_size: int = 64,
+        layernorm_epsilon: float = 1e-5,
+    ) -> None:
+        """
+        LayerNorm/RMSNorm + Linear + GELU + Linear
+
+        Note that this test checks only if the module runs
+        as when chaining multiple modules it is hard to validate
+        numerical accuracy.
+        """
+
+        # Make input shape
+        in_shape = (sequence_length, batch_size, hidden_size)
+        ffn_shape = in_shape[:-1] + (ffn_hidden_size,)
+
+        # Skip invalid configurations
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=ffn_shape, device=device)
+        quantization_needed = quantized_compute or quantized_weight
+        if quantization is None and quantization_needed:
+            pytest.skip("Quantization scheme is not specified")
+        if quantization is not None and not quantization_needed:
+            pytest.skip("Quantization scheme is not used")
+
+        # Random data
+        _, x_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        _, dy_test = make_reference_and_test_tensors(
+            in_shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Implementation with fusible operations
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            if normalization == "LayerNorm":
+                norm = te_ops.LayerNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            else:
+                norm = te_ops.RMSNorm(
+                    hidden_size,
+                    eps=layernorm_epsilon,
+                    device=device,
+                    dtype=dtype,
+                )
+            ffn1 = te_ops.Linear(
+                hidden_size,
+                ffn_hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+            act = te_ops.GELU()
+            ffn2 = te_ops.Linear(
+                ffn_hidden_size,
+                hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            )
+        forward = te_ops.Sequential(norm, ffn1, act, ffn2)
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)