OpenDAS / TransformerEngine - Commits

Commit ab3e5a92, authored May 09, 2025 by yuguo

    Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0

Changes: 174 files in the commit; showing 20 changed files with 3105 additions and 733 deletions (+3105 -733).
Changed files shown:

  tests/pytorch/references/quantize_scale_calc.py              +60   -0
  tests/pytorch/references/ref_per_tensor_cs.py                +4    -55
  tests/pytorch/test_cpu_offloading.py                         +101  -42
  tests/pytorch/test_cuda_graphs.py                            +8    -0
  tests/pytorch/test_float8_blockwise_gemm_exact.py            +972  -0
  tests/pytorch/test_float8_blockwise_scaling_exact.py         +493  -0
  tests/pytorch/test_float8_current_scaling_exact.py           +15   -7
  tests/pytorch/test_float8blockwisetensor.py                  +463  -0
  tests/pytorch/test_float8tensor.py                           +27   -1
  tests/pytorch/test_fused_optimizer.py                        +28   -0
  tests/pytorch/test_fused_rope.py                             +58   -78
  tests/pytorch/test_fusible_ops.py                            +129  -6
  tests/pytorch/test_multi_tensor.py                           +7    -5
  tests/pytorch/test_numerics.py                               +248  -44
  tests/pytorch/test_permutation.py                            +177  -274
  tests/pytorch/test_sanity.py                                 +63   -4
  transformer_engine/common/CMakeLists.txt                     +18   -0
  transformer_engine/common/activation/activation_template.h   +4    -4
  transformer_engine/common/common.h                           +60   -4
  transformer_engine/common/fused_rope/fused_rope.cu           +170  -209
tests/pytorch/references/quantize_scale_calc.py (new file, mode 100644, +60 -0):

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Tuple

import torch


def scale_from_amax_tensor(
    x_dtype: torch.dtype,
    amax: torch.Tensor,
    quant_dtype: torch.dtype,
    *,
    eps: float,
    pow_2_scales: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Derives quantization and dequantization scales from amax and options.

    Reference implementation for scale calculation.

    Returns:
        - scale: quantization scales
        - scale_inv: dequantization scales
        - amax: amax tensor with updates made for extrema values.
    """
    assert amax.dtype == torch.float, "amax must be a float tensor."
    fp8_max = torch.finfo(quant_dtype).max
    # Clamping amax to avoid division by small numbers
    amax = torch.max(amax, torch.tensor(eps))
    # Compute scale factor
    scale = torch.div(fp8_max, amax)
    # Note frexp doesn't give back inf for exponent with an inf input
    # We take care of inf before pow_2_scales
    scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale)
    if pow_2_scales:
        # Calculate rounded down exponent
        _, exp = torch.frexp(scale)
        # Positive numbers are always returned as mant, exp with
        # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one
        # because of the shift. Subnormal and zero cases need not be considered
        # because the smallest possible result of fp8_max / amax is still normal.
        exp = exp - 1
        # No subnormals and zero.
        assert (exp > -127).all()
        unity = torch.tensor([1.0], device=exp.device)
        torch.ldexp(unity, exp, out=scale)
        # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales.
        # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
        # calculation.
        scale = torch.where(amax == float("inf"), 0.0, scale)
    # Handle overflow cases for amax zero causing NaN
    scale = torch.where(amax == 0, 1.0, scale)
    # Compute scale_inv
    scale_inv = torch.reciprocal(scale)
    return scale, scale_inv, amax
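For orientation, a minimal usage sketch of the reference above; the concrete numbers are illustrative and not part of the commit. With eps=0 and pow-2 scales enabled, an amax of 3.0 against the e4m3 maximum of 448 yields a raw scale of about 149.3, which the frexp/ldexp logic rounds down to 2**7 = 128:

    scale, scale_inv, amax = scale_from_amax_tensor(
        torch.float32,           # x_dtype of the unquantized input
        torch.tensor([3.0]),     # amax
        torch.float8_e4m3fn,     # quant_dtype; torch.finfo(...).max == 448
        eps=0.0,
        pow_2_scales=True,
    )
    # 448 / 3.0 ~= 149.3 rounds down to the nearest power of two.
    assert scale.item() == 128.0 and scale_inv.item() == 1.0 / 128.0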
tests/pytorch/references/ref_per_tensor_cs.py (+4 -55):

@@ -6,63 +6,16 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType_To_Torch
+from references.quantize_scale_calc import scale_from_amax_tensor
 
-
-# Compute scale and scale_inv from amax
-def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
-    # Clamping amax to avoid division by small numbers
-    amax = torch.max(amax, torch.tensor(eps))
-    # Compute scale factor
-    scale = torch.div(fp8_max, amax)
-    # Note frexp doesn't give back inf for exponent with an inf input
-    # We take care of inf before pow_2_scales
-    # option1: set scale to fp32 max when scale is inf
-    scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale)
-    # option2: when scale is inf, set scale to 1
-    scale = torch.where(scale == torch.inf, 1.0, scale)
-    if pow_2_scales:
-        # Calculate rounded down exponent
-        _, exp = torch.frexp(scale)
-        # Positive numbers are always returned as mant, exp with
-        # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with
-        # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because
-        # of the shift. Subnormal and zero cases need not be considered because
-        # the smallest possible result of fp8_max / amax is still normal.
-        exp = exp - 1
-        # No subnormals and zero.
-        assert (exp > -127).all()
-        # TODO: If/when adding a URM option an option is to cap to 126
-        # rather than allowing the full range of FP32 (2 - 2^-23) x 2^127;
-        # addresses cases where adding a mantissa overflows into inf scales.
-        # Not necessary currently without additional scale smudging options.
-        unity = torch.tensor([1.0], device=exp.device)
-        torch.ldexp(unity, exp, out=scale)
-        # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales.
-        # Return 0.0 for 0.0 scale for consistency with non-pow2 scale
-        # calculation.
-        scale = torch.where(amax == float("inf"), 0.0, scale)
-    # Handle overflow cases for amax zero causing NaN
-    scale = torch.where(amax == 0, 1.0, scale)
-    # Compute scale_inv
-    scale_inv = torch.reciprocal(scale)
-    return scale, scale_inv
 
 # compute amax and scale
 def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
     x_fp32 = x.to(torch.float32)
     amax = torch.amax(torch.abs(x_fp32)).view(1)
-    fp8_max = torch.finfo(quant_dtype).max
-    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
-    # Clamping amax to avoid division by small numbers
-    amax = torch.max(amax, torch.tensor(eps))
-    return scale, scale_inv, amax
+    assert amax.dtype == torch.float, "amax must be a float tensor."
+    return scale_from_amax_tensor(
+        torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales
+    )
 
 
 def _multi_dim_transpose(tensor):
@@ -113,7 +66,3 @@ def ref_per_tensor_cs_cast(
     qx_t = _multi_dim_transpose(qx)
     sx_t = sx
     return qx, sx, qx_t, sx_t
-
-
-def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
-    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
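If any out-of-tree caller still depends on the removed public wrapper, a hypothetical shim (not part of this commit) can recover the old two-value result from the shared reference. Note the signature change: the old fp8_max argument is replaced by quant_dtype, from which scale_from_amax_tensor derives fp8_max itself:

    def ref_compute_scale_and_scale_inv_from_amax(amax, quant_dtype, eps, pow_2_scales):
        # quant_dtype replaces the old fp8_max argument; scale_from_amax_tensor
        # computes fp8_max = torch.finfo(quant_dtype).max internally and also
        # returns the clamped amax, which this shim discards.
        scale, scale_inv, _ = scale_from_amax_tensor(
            torch.float32, amax, quant_dtype, eps=eps, pow_2_scales=pow_2_scales
        )
        return scale, scale_inv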
tests/pytorch/test_cpu_offloading.py (+101 -42):

@@ -2,41 +2,84 @@
 #
 # See LICENSE for license information.
+import os
+from contextlib import nullcontext
+
 import pytest
 import torch
-from contextlib import nullcontext
 
 import transformer_engine.pytorch as te
+from transformer_engine.common import recipe
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 
-# Check if FP8 supported
+# Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+
+fp8_recipes = [
+    None,  # non-fp8
+    # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading does not work yet
+    recipe.Float8CurrentScaling(),
+    recipe.DelayedScaling(),
+]
 
 SIZE = 512
+NUM_HEADS = 8
+NUM_LAYERS = 5
+EPSILON = 0.1
+
+# Flash attention saves some internal tensor for the backward pass
+# that cannot be offloaded to CPU.
+assert os.getenv("NVTE_FLASH_ATTN") == "0"
 
-models = {
-    "linear": te.Linear,
-    "layernorm_mlp": te.LayerNormMLP,
-    "layernorm_linear": te.LayerNormLinear,
+# Offloading is supported for attention only for fused and flash attention backends,
+# so the use of bfloat16 is required.
+#
+# For the TransformerLayer, activation offloading with dropout is not supported,
+# so we set hidden_dropout to 0.0.
+model_types = {
+    "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16),
+    "multihead_attention": lambda: te.MultiheadAttention(SIZE, NUM_HEADS, params_dtype=torch.bfloat16),
+    "transformer_layer": lambda: te.TransformerLayer(
+        SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0
+    ),
 }
 
 
 def _get_input():
-    return torch.empty((128, SIZE, SIZE)).cuda()
+    return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda()
 
 
-def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
-    input_layer = model_cls(SIZE, SIZE)
-    hidden_layer = model_cls(SIZE, SIZE)
-    output_layer = model_cls(SIZE, SIZE)
+def _get_fp8_weight_cache_size(models, fp8_recipe):
+    """
+    Calculate the total FP8 weight cache size (in MB) for a list of models.
+    """
+    if fp8_recipe is None:
+        return 0
 
-    input = _get_input()
+    params_bytes = 0
+    for model in models:
+        for name, param in model.named_parameters():
+            if "weight" in name:
+                params_bytes += param.numel()
+
+    # One byte for columnwise and one byte for rowwise,
+    # hence multiply by 2 and convert to MB;
+    # there is 1 byte of scale per 32 elements in mxFP8
+    factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1
+    return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2)
+
+
+def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload):
+    tensor = _get_input()
     if cpu_offload:
         offload_context, sync_function = te.get_cpu_offload_context(
             enabled=True,
-            num_layers=2,
-            model_layers=3,
+            num_layers=len(models) - 1,
+            model_layers=len(models),
             offload_activations=True,
            offload_weights=False,
         )
@@ -44,42 +87,58 @@ def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload):
         offload_context = nullcontext()
         sync_function = lambda x: x
 
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = input_layer(input)
-    out = sync_function(out)
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = hidden_layer(out)
-    out = sync_function(out)
-    with te.fp8_autocast(enabled=fp8), offload_context:
-        out = output_layer(out)
-    out = sync_function(out)
-
-    max_mem_used = torch.cuda.memory_allocated() / 1024**2
-    out.sum().backward()
-    del input_layer
-    del hidden_layer
-    del output_layer
-    del input
-    del out
+    for model in models:
+        with te.fp8_autocast(
+            enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe
+        ), offload_context:
+            tensor = model(tensor)
+        tensor = sync_function(tensor)
+
+    max_mem_used = torch.cuda.memory_allocated() / (1024**2)
     torch.cuda.synchronize()
 
     return max_mem_used
 
 
-@pytest.mark.parametrize("fp8", [True, False])
-@pytest.mark.parametrize("model_key", models.keys())
-def test_cpu_offload(fp8, model_key) -> None:
-    if fp8 and not fp8_available:
-        pytest.skip(reason_for_no_fp8)
+@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
+@pytest.mark.parametrize("model_key", model_types.keys())
+def test_cpu_offload(fp8_recipe, model_key) -> None:
+    """
+    We run three configurations:
+    (1) No offloading: All activations remain on the GPU between forward and backward passes.
+    (2) No offloading (one layer): Only the first layer's activations remain on the GPU between
+        forward and backward passes.
+    (3) With offloading (all layers): Only the last layer's activations remain on the GPU
+        between forward and backward passes, while all other layers are offloaded to the CPU.
 
-    model_cls = models[model_key]
+    We expect the memory consumption of configurations (2) and (3) to be similar, with
+    the difference being the size of the FP8 cache that is not offloaded to the CPU.
+    We also expect this memory consumption to be smaller than in scenario (1).
+    """
+    model_cls = model_types[model_key]
+    models_list = [model_cls() for _ in range(NUM_LAYERS)]
 
-    without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False)
-    with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True)
+    if fp8_recipe and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8_recipe is not None:
+        if fp8_recipe.mxfp8() and not mxfp8_available:
+            pytest.skip(reason_for_no_mxfp8)
+
+    without_offloading = _measure_memory_between_forward_and_backward(
+        models_list, fp8_recipe, False
+    )
+    without_offloading_one_layer = _measure_memory_between_forward_and_backward(
+        models_list[:1], fp8_recipe, False
+    )
+    with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True)
 
     assert with_offloading < without_offloading
+
+    # The only difference between the memory consumption of with_offloading
+    # and without_offloading_one_layer should be the size of the FP8 weights cache,
+    # which is not offloaded to the CPU.
+    memory_consumption_diff = abs(with_offloading - without_offloading_one_layer)
+    assert (
+        memory_consumption_diff
+        < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON
+    )
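As a back-of-envelope check on the bound asserted above, using the constants from this diff (the arithmetic is illustrative, not something the test prints): five 512x512 layers hold 5 * 512 * 512 weight elements, each cached as one rowwise plus one columnwise FP8 byte, and an mxFP8 recipe would add one scale byte per 32 elements:

    params_bytes = 5 * 512 * 512                 # NUM_LAYERS * SIZE * SIZE = 1,310,720 elements
    cache_mb = 2 * params_bytes / (1024**2)      # 2.5 MB for non-mxFP8 recipes
    cache_mb_mxfp8 = cache_mb * (1 + 1 / 32)     # ~2.578 MB with mxFP8 scale-inv bytes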
tests/pytorch/test_cuda_graphs.py (+8 -0):

@@ -30,6 +30,9 @@ if IS_HIP_EXTENSION:
 # Check if FP8 is supported.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)
 mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
@@ -58,6 +61,7 @@ fp8_recipes = [
     recipe.DelayedScaling(),
     recipe.MXFP8BlockScaling(),
     recipe.Float8CurrentScaling(),
+    recipe.Float8BlockScaling(),
 ]
 
 # Supported data types
@@ -328,9 +332,13 @@ def test_make_graphed_callables(
         pytest.skip("FP8 needed for FP8 parameters.")
     if fp8_weight_caching and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
+    if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8_recipe.mxfp8() and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
+    if fp8_recipe.float8_block_scaling() and module == "linear_op":
+        pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
 
     # Run model with different CUDA graph settings.
     model_config = model_configs[model_config]
     kwargs = dict(
tests/pytorch/test_float8_blockwise_gemm_exact.py (new file, mode 100644, +972 -0):

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)

from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm


def fp8_blockwise_gemm_supported() -> bool:
    supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
    return supported


def cublas_gemm_fp8_blockwise_case(
    x_dtype,
    w_dtype,
    out_dtype,
    M,
    K,
    N,
    noise_type,
    x_magnitude,
    w_magnitude,
    accumulate,
    use_split_accumulator,
    is_x_1d_scaled,
    is_w_1d_scaled,
    *,
    x_columnwise: bool = False,
    w_columnwise: bool = False,
    use_bias: bool = False,
    use_gelu: bool = False,
    use_grad: bool = False,
    atol: float = 0.0,
    rtol: float = 0.0,
):
    if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
        pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
    if not (is_x_1d_scaled or is_w_1d_scaled):
        pytest.skip("FP8 GEMM doesn't support 2dimensional qtile by 2dimensional qtile")
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x_shape = (K, M) if x_columnwise else (M, K)
    w_shape = (K, N) if w_columnwise else (N, K)

    # generate random input and weight
    if noise_type == "uniform":
        x = torch.rand(x_shape, dtype=torch.float32, device=device) * x_magnitude * 2 - x_magnitude
        w = torch.rand(w_shape, dtype=torch.float32, device=device) * w_magnitude * 2 - w_magnitude
    elif noise_type == "normal":
        x = torch.randn(x_shape, dtype=torch.float32, device=device) * x_magnitude
        w = torch.randn(w_shape, dtype=torch.float32, device=device) * w_magnitude
    else:
        assert False

    # Setup out tensor if accumulate is True
    if accumulate:
        out = torch.randn((M, N), dtype=out_dtype, device=device) * x_magnitude
    else:
        out = None

    assert not (use_bias and use_grad), "Bias grad not supported by GEMM"

    # Set quantize_op and quantization parameters
    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
    x_block_scaling_dim = 1 if is_x_1d_scaled else 2
    w_block_scaling_dim = 1 if is_w_1d_scaled else 2
    x_te_dtype = TE_DType[x_dtype]
    w_te_dtype = TE_DType[w_dtype]
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )

    # Quantize x and w
    qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)
    qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)

    if not use_bias:
        bias = None
    else:
        bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)

    # Reference GEMM
    ref_gemm = CuBLASRefBlockwiseGemm()
    scale_decoder = CuBLASScaleMunger()
    qx_data = (
        qx._columnwise_data.view(dtype=x_dtype)
        if x_columnwise
        else qx._rowwise_data.view(dtype=x_dtype)
    )
    qw_data = (
        qw._columnwise_data.view(dtype=w_dtype)
        if w_columnwise
        else qw._rowwise_data.view(dtype=w_dtype)
    )
    ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
    ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
    y_ref = ref_gemm.qgemm(
        qx=qx_data,
        qw=qw_data,
        out_dtype=out_dtype,
        demunged_sx=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(M, K), scales=ref_scales_x, tile_shape=x_quant_tile_shape
        ),
        demunged_sw=CuBLASScaleMunger.demunge_scale_shape_from_backend(
            qtensor_shape=(N, K), scales=ref_scales_w, tile_shape=w_quant_tile_shape
        ),
        quant_tile_shape_x=x_quant_tile_shape,
        quant_tile_shape_w=w_quant_tile_shape,
        bias=bias,
        out=out.clone() if accumulate else None,
        accumulate=accumulate,
        use_split_accumulator=use_split_accumulator,
    )

    # Allocate cuBLAS workspace
    workspace_size = 0
    workspace = torch.empty(0, dtype=torch.uint8, device=device)
    transa = True if not w_columnwise else False
    transb = False if not x_columnwise else True
    out_quantizer = None
    assert not (use_gelu and use_bias), "Bias and GELU not supported by GEMM"
    aux_tensor = torch.randn((M, N), dtype=out_dtype, device=device) if use_gelu else None
    aux_tensor_ref = aux_tensor.clone() if use_gelu else None
    bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

    # cuBLAS GEMM
    # return type is out, bias_grad, gelu_input, extra_output
    # We are just capturing out.
    y = tex.generic_gemm(
        qw,
        transa,
        qx,
        transb,
        out.clone() if accumulate else None,
        out_quantizer,
        TE_DType[out_dtype],
        bias,
        bias_dtype,
        use_gelu,
        aux_tensor,
        use_grad,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
    )[0]

    # just in case of accumulation, make sure y_ref and y are not the same tensor
    assert y_ref is not y, "y_ref and y should not be the same tensor"

    # Reset nans to zeros because torch.assert_close does not assume nans to be equal
    assert not torch.isnan(y_ref.float()).all(), "All elements are nan"
    y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref)
    y = torch.where(y.isnan(), torch.zeros_like(y), y)

    if use_gelu:
        # Check
        if use_grad:
            # With use_grad, GEMM should use aux tensor to calculate
            # gradient
            gelu_ref = tex.dgelu(y_ref, aux_tensor_ref, None)
            # TODO: How do we decide whether this is acceptably close?
            # Could also try to put the activation inside the reference
            # before the output cast to see different tolerances.
            torch.testing.assert_close(y, gelu_ref, atol=1e-3, rtol=1e-2)
        else:
            # aux tensor is pre-gelu aux output. Verify against y_ref.
            torch.testing.assert_close(aux_tensor, y_ref, atol=atol, rtol=rtol)
            act = torch.nn.GELU()
            gelu_ref = act(y_ref)
            # gelu_ref = tex.gelu(y_ref, None)
            torch.testing.assert_close(y, gelu_ref, atol=atol, rtol=rtol)
    else:
        torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)


def cublas_gemm_test_constraint_enforced(
    x_dtype,
    w_dtype,
    out_dtype,
    M,
    K,
    N,
    accumulate,
    use_split_accumulator,
    is_x_1d_scaled,
    is_w_1d_scaled,
    *,
    x_columnwise: bool = False,
    w_columnwise: bool = False,
    use_bias: bool = False,
    use_gelu: bool = False,
    use_grad: bool = False,
    expected_err_msg="CUBLAS_STATUS_NOT_SUPPORTED",
    expected_err_cls=RuntimeError,
):
    if not fp8_blockwise_gemm_supported():
        pytest.skip("CUDA version does not support blockwise FP8 gemm.")

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    x_shape = (K, M) if x_columnwise else (M, K)
    w_shape = (K, N) if w_columnwise else (N, K)

    # generate random input and weight
    x = torch.rand(x_shape, dtype=torch.float32, device=device) * 2.0 - 1.0
    w = torch.rand(w_shape, dtype=torch.float32, device=device) * 2.0 - 1.0

    # Setup out tensor if accumulate is True
    if accumulate:
        out = torch.randn((M, N), dtype=out_dtype, device=device)
    else:
        out = None

    # Set quantize_op and quantization parameters
    x_quant_tile_shape = (1, 128) if is_x_1d_scaled else (128, 128)
    w_quant_tile_shape = (1, 128) if is_w_1d_scaled else (128, 128)
    x_block_scaling_dim = 1 if is_x_1d_scaled else 2
    w_block_scaling_dim = 1 if is_w_1d_scaled else 2
    x_te_dtype = TE_DType[x_dtype]
    w_te_dtype = TE_DType[w_dtype]
    x_quantizer = Float8BlockQuantizer(
        fp8_dtype=x_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=x_block_scaling_dim,
    )
    w_quantizer = Float8BlockQuantizer(
        fp8_dtype=w_te_dtype,
        rowwise=True,
        columnwise=True,
        amax_epsilon=0.0,
        force_pow_2_scales=True,
        block_scaling_dim=w_block_scaling_dim,
    )

    # Quantize x and w
    qx = x_quantizer.make_empty(x_shape, dtype=x_dtype, device=device, requires_grad=False)
    qx = x_quantizer.update_quantized(x, qx)
    qw = w_quantizer.make_empty(w_shape, dtype=w_dtype, device=device, requires_grad=False)
    qw = w_quantizer.update_quantized(w, qw)

    if not use_bias:
        bias = None
    else:
        bias = torch.randn((1, N), dtype=torch.bfloat16, device=device)

    # Allocate cuBLAS workspace
    workspace_size = 0
    workspace = torch.empty(0, dtype=torch.uint8, device=device)
    transa = True if not w_columnwise else False
    transb = False if not x_columnwise else True
    out_quantizer = None
    grad = use_grad
    gelu_in = None if not use_gelu else torch.randn((M, N), dtype=out_dtype, device=device)
    bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]

    # cuBLAS GEMM
    # return type is out, bias_grad, gelu_input, extra_output
    # We are just capturing out.
    with pytest.raises(expected_err_cls, match=expected_err_msg):
        y = tex.generic_gemm(
            qw,
            transa,
            qx,
            transb,
            out.clone() if accumulate else None,
            out_quantizer,
            TE_DType[out_dtype],
            bias,
            bias_dtype,
            use_gelu,
            gelu_in,
            grad,
            workspace,
            workspace.shape[0],
            accumulate,
            use_split_accumulator,
        )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (128, 128, 128),
        (256, 128, 256),
        # non 128x128 divisible input shapes
        (320, 128, 336),
        (320, 64, 336),
        # k > 128
        (256, 256, 256),
        (320, 256, 336),
        (1024, 4096, 1024),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_shape_varying(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
    )


@pytest.mark.parametrize("M, K, N", [(256, 128, 256), (320, 256, 336)])
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal", "uniform"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-28, 1, 1e3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_accumulate_magnitude_varying(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
        # non 128x128 divisible input shapes
        (320, 64, 336),
        # k > 128
        (256, 256, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1e-3], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_cublas_gemm_fp8_blockwise_bias(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
):
    cublas_gemm_fp8_blockwise_case(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        use_bias=True,
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
        # non 128x128 divisible input shapes
        (16, 128, 128),
        (320, 64, 336),
        # k > 128
        (4096, 128, 4096),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize(
    "is_x_columnwise, is_w_columnwise",
    [(True, False), (True, True), (False, True)],
    ids=["colxrow", "colxcol", "rowxcol"],
)
def test_cublas_gemm_fp8_blockwise_columnwise(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
    is_x_columnwise, is_w_columnwise,
):
    cublas_gemm_fp8_blockwise_case(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        x_columnwise=is_x_columnwise,
        w_columnwise=is_w_columnwise,
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
        # non 128x128 divisible input shapes
        (320, 64, 336),
        # k > 128
        (256, 256, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("noise_type", ["normal"], ids=str)
@pytest.mark.parametrize("x_magnitude", [1], ids=str)
@pytest.mark.parametrize("w_magnitude", [1], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
@pytest.mark.parametrize("use_grad", [True], ids=["grad"])
def test_cublas_gemm_fp8_gelu(
    x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled, use_grad,
):
    # NOTE: cuBLAS doesn't complain with not use_grad, but the tests don't succeed
    # so the epilogue is disabled on the transformer engine side.
    if not use_grad and not (is_x_1d_scaled and not is_w_1d_scaled):
        pytest.skip(
            "CUBLASLT_EPILOGUE_GELU_AUX epilogue is only supported for 1Dx2D (cuBLAS 2Dx1D)."
        )
    cublas_gemm_fp8_blockwise_case(
        x_dtype, w_dtype, out_dtype, M, K, N, noise_type, x_magnitude, w_magnitude,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        use_gelu=True,
        use_grad=use_grad,
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [False], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_split_accumulator_enforced(
    x_dtype, w_dtype, out_dtype, M, K, N,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    cublas_gemm_test_constraint_enforced(
        x_dtype, w_dtype, out_dtype, M, K, N,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_bgrad_not_supported(
    x_dtype, w_dtype, out_dtype, M, K, N,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    # NOTE: BGRAD epilogue is not supported for fp8.
    cublas_gemm_test_constraint_enforced(
        x_dtype, w_dtype, out_dtype, M, K, N,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        use_grad=True,
        use_bias=True,
        expected_err_msg="Epilogue requested outside of the available",
    )


@pytest.mark.parametrize(
    "M, K, N",
    [
        # k = 128
        (256, 128, 256),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no_bias"])
@pytest.mark.parametrize("use_grad", [True, False], ids=["grad", "no_grad"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_gelu_unsupported_cases_error(
    x_dtype, w_dtype, out_dtype, M, K, N,
    accumulate, use_bias, use_grad, use_split_accumulator,
    is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    if use_grad and not use_bias and out_dtype == torch.bfloat16:
        pytest.skip("DGELU epilogue is supported for bfloat16.")
    elif use_grad and not use_bias:
        expected_err = "an unsupported value or parameter was passed"
    else:
        expected_err = "Epilogue requested outside of the available"
    cublas_gemm_test_constraint_enforced(
        x_dtype, w_dtype, out_dtype, M, K, N,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        use_grad=use_grad,
        use_bias=use_bias,
        use_gelu=True,
        expected_err_msg=expected_err,
    )


@pytest.mark.parametrize("M, K, N", [(256, 128, 256)])
@pytest.mark.parametrize("x_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (True, True), (False, True)],
    ids=["1Dx2D", "1Dx1D", "2Dx1D"],
)
def test_illegal_dtype_enforced(
    x_dtype, w_dtype, out_dtype, M, K, N,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    # e5m2 by e5m2 not supported.
    cublas_gemm_test_constraint_enforced(
        x_dtype, w_dtype, out_dtype, M, K, N,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
    )


@pytest.mark.parametrize("M, K, N", [(256, 128, 256)])
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str)
@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(False, False)],
    ids=["2Dx2D"],
)
def test_illegal_2D_by_2D_enforced(
    x_dtype, w_dtype, out_dtype, M, K, N,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    # 2D block quantization by 2D block quantization is not supported.
    expected_err_msg = "Only 1D by 1D, 1D by 2D, and 2D by 1D block scaling supported"
    cublas_gemm_test_constraint_enforced(
        x_dtype, w_dtype, out_dtype, M, K, N,
        accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        expected_err_msg=expected_err_msg,
    )


@pytest.mark.parametrize(
    "M, K, N, legalX1d, legalX2d",
    [
        # M dim unconstrained when X is 2D.
        (255, 128, 256, False, True),
        # K must be multiple of 16
        (256, 120, 256, False, False),
        # N must be a multiple of 8
        (256, 128, 252, False, False),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("w_dtype", [torch.float8_e4m3fn], ids=str)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16], ids=str)
@pytest.mark.parametrize("accumulate", [False], ids=["no_accumulate"])
@pytest.mark.parametrize("use_split_accumulator", [True], ids=["split_acc"])
@pytest.mark.parametrize(
    "is_x_1d_scaled, is_w_1d_scaled",
    [(True, False), (False, True), (True, True)],
    ids=["1Dx2D", "2Dx1D", "1Dx1D"],
)
def test_unaligned_shapes(
    x_dtype, w_dtype, out_dtype, M, K, N, legalX1d, legalX2d,
    accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
) -> None:
    legal = legalX1d if is_x_1d_scaled else legalX2d
    if not legal:
        cublas_gemm_test_constraint_enforced(
            x_dtype, w_dtype, out_dtype, M, K, N,
            accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
            expected_err_msg="dimension requirement",
        )
    else:
        cublas_gemm_fp8_blockwise_case(
            x_dtype, w_dtype, out_dtype, M, K, N,
            "uniform",  # noise type
            1.0,  # x_magnitude
            1.0,  # w_magnitude
            accumulate, use_split_accumulator, is_x_1d_scaled, is_w_1d_scaled,
        )
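For intuition about the 1D/2D parametrization above, here is a small sketch of the logical (demunged, unpadded) scale counts implied by the quantizer's tile shapes. This is an inference from the (1, 128) and (128, 128) tiles used in these tests, not something the tests assert directly:

    import math

    rows, cols = 320, 64  # one of the "non 128x128 divisible" shapes above
    # 1D scaling uses (1, 128) tiles: one scale per row per 128-column block.
    scales_1d = rows * math.ceil(cols / 128)                   # 320 * 1 = 320 scales
    # 2D scaling uses (128, 128) tiles: one scale per 128x128 block.
    scales_2d = math.ceil(rows / 128) * math.ceil(cols / 128)  # 3 * 1 = 3 scales

The backend layout additionally pads these shapes, which is why the scaling-exact tests below mask out the "don't care" scale entries before comparing.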
tests/pytorch/test_float8_blockwise_scaling_exact.py (new file, mode 100644, +493 -0):

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Tuple
import math
import os
import pathlib

import pytest
import torch

import transformer_engine as te
import transformer_engine_torch as tex
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.common.recipe import Float8BlockScaling
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)

from references.blockwise_quantizer_reference import (
    BlockwiseQuantizerReference,
    QuantizeResult,
)
from test_float8_current_scaling_exact import (
    TestFP8RecipeLinearBase,
    TestFP8RecipeLayerNormLinearBase,
)

# read env variable NVTE_TEST_FLOAT8_BLOCK_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory
TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps"
tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR")
if tensor_dump_dir_env is not None:
    TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)

recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()


class GetRecipes:
    @staticmethod
    def none():
        return None

    @staticmethod
    def fp8_blockwise():
        # return default configs
        return Float8BlockScaling()


def initialize_for_many_scales(
    x_shape_2d: Tuple[int, int],
    tile_shape: Tuple[int, int],
    *,
    dtype: torch.dtype,
    device: str,
) -> torch.Tensor:
    """
    Put separate distributions into each quantization tile
    to avoid many tiles having similar scale values and
    causing false passes.
    """
    tile_grid_shape = (
        math.ceil(x_shape_2d[0] / tile_shape[0]),
        math.ceil(x_shape_2d[1] / tile_shape[1]),
    )
    # Arbitrary size
    max_val = 8192.0
    # Make a uniform distribution of [-max_val, max_val]
    tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val
    result = torch.empty(x_shape_2d, dtype=dtype, device=device)
    tile_elements = tile_shape[0] * tile_shape[1]
    for i in range(tile_grid_shape[0]):
        for j in range(tile_grid_shape[1]):
            target = tile_extrema[i, j].item()
            step = target / (tile_elements)
            if target == 0:
                tile = torch.zeros(tile_shape, dtype=dtype, device=device)
            else:
                tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device)
                tile = tile.reshape(*tile_shape)
            min_dst_vals = (i * tile_shape[0], j * tile_shape[1])
            max_dst_vals = (
                min((i + 1) * tile_shape[0], x_shape_2d[0]),
                min((j + 1) * tile_shape[1], x_shape_2d[1]),
            )
            max_src_vals = (
                max_dst_vals[0] - min_dst_vals[0],
                max_dst_vals[1] - min_dst_vals[1],
            )
            result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[
                : max_src_vals[0], : max_src_vals[1]
            ]
    return result


def check_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    quant_dtype: torch.dtype,
    eps: float,
    return_transpose: bool,
    pow_2_scales: bool,
    tile_size: Tuple[int, int],
) -> None:
    te_dtype = TE_DType[quant_dtype]
    if tile_size == (1, 128):
        block_scaling_dim = 1
    elif tile_size == (128, 128):
        block_scaling_dim = 2
    else:
        raise ValueError("Non support tile size")
    # This test runs a comparison of the ref class versus the class using
    # CUDA kernels to quantize. They should quantize identically for pixels
    # that are not DC values in the scale factor shape.
    ref_quantizer = BlockwiseQuantizerReference()
    sut_quantizer = Float8BlockQuantizer(
        fp8_dtype=te_dtype,
        rowwise=True,
        columnwise=return_transpose,
        amax_epsilon=eps,
        force_pow_2_scales=pow_2_scales,
        block_scaling_dim=block_scaling_dim,
    )

    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Input
    x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device)

    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)

    assert x_fp8_sut._rowwise_data is not None
    qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
    assert x_fp8_sut._rowwise_scale_inv is not None
    sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv
    qx_t = x_fp8_sut._columnwise_data
    sx_t = x_fp8_sut._columnwise_scale_inv

    qresult_ref = ref_quantizer.quantize(
        x,
        quant_dtype=quant_dtype,
        return_transpose=return_transpose,
        eps=eps,
        pow_2_scales=pow_2_scales,
        quant_tile_shape=tile_size,
    )
    qx_ref, sx_ref, qx_t_ref, sx_t_ref = (
        qresult_ref.data,
        qresult_ref.scale,
        qresult_ref.data_t,
        qresult_ref.scale_t,
    )

    # Check
    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
    # Zero out values that are don't care values
    # Scale format has padding.
    scale_mask = torch.ones(
        (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device
    )
    scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
        QuantizeResult(qx, scale_mask, None, None), tile_size
    ).scale
    sx = sx * scale_mask
    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
    if return_transpose:
        assert qx_t is not None
        qx_t = qx_t.view(dtype=quant_dtype)
        assert qx_t_ref is not None
        assert sx_t is not None
        assert sx_t_ref is not None
        scale_mask = torch.ones(
            (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])),
            device=sx_t.device,
        )
        scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend(
            QuantizeResult(qx_t, scale_mask, None, None), tile_size
        ).scale
        sx_t = sx_t * scale_mask
        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
    else:
        # should be None
        assert qx_t is None and qx_t_ref is None
        assert sx_t is None and sx_t_ref is None


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
        (256, 256),
        (256, 1024),
        (1024, 256),
        # Padding required cases
        (256, 272),
        (303, 300),
        (305, 256),
        # Some larger tiles.
        (2000, 2000),
        (2048, 2000),
        (2000, 1024),
        (2048, 1024),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [True], ids=["pow2scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    quant_dtype: torch.dtype,
    eps: float,
    return_transpose: bool,
    pow_2_scales: bool,
    tile_size: Tuple[int, int],
) -> None:
    check_quantization_block_tiling_versus_reference(
        x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (256, 256),
        (2048, 1024),
        # Padding required cases
        (256, 272),
        (303, 300),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize(
    "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"]
)
@pytest.mark.parametrize("pow_2_scales", [False], ids=["fp32scales"])
@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"])
def test_quantization_block_tiling_versus_reference_fp32_scales(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    quant_dtype: torch.dtype,
    eps: float,
    return_transpose: bool,
    pow_2_scales: bool,
    tile_size: Tuple[int, int],
) -> None:
    check_quantization_block_tiling_versus_reference(
        x_dtype, M, N, quant_dtype, eps, return_transpose, pow_2_scales, tile_size
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        # full tile cases
        (128, 128),
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str)
@pytest.mark.parametrize("eps", [0], ids=["eps_0"])
@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "fp32scales"])
@pytest.mark.parametrize("tile_size", [(128, 128)])
@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"])
def test_quantization_block_tiling_extrema_versus_reference(
    x_dtype: torch.dtype,
    M: int,
    N: int,
    quant_dtype: torch.dtype,
    eps: float,
    pow_2_scales: bool,
    tile_size: Tuple[int, int],
    extrema_high: bool,
) -> None:
    # This test runs a single tile through a quantizer as a way to test
    # branch coverage of scale computation.
    te_dtype = TE_DType[quant_dtype]
    if tile_size == (1, 128):
        block_scaling_dim = 1
    elif tile_size == (128, 128):
        block_scaling_dim = 2
    else:
        raise ValueError("Non support tile size")
    ref_quantizer = BlockwiseQuantizerReference()
    sut_quantizer = Float8BlockQuantizer(
        fp8_dtype=te_dtype,
        rowwise=True,
        columnwise=False,
        amax_epsilon=eps,
        force_pow_2_scales=pow_2_scales,
        block_scaling_dim=block_scaling_dim,
    )
    # Setup device and random seed
    device = "cuda"
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    return_transpose = False
    # Input
    if extrema_high:
        x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device)
    else:
        x = torch.zeros((M, N), dtype=x_dtype, device=device)
    # Run cast and transpose kernel
    # Internal call ops.quantize_tensorwise
    x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False)
    x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut)
    qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype)
    sx = x_fp8_sut._rowwise_scale_inv
    qresult_ref = ref_quantizer.quantize(
        x,
        quant_dtype=quant_dtype,
        return_transpose=return_transpose,
        eps=eps,
        pow_2_scales=pow_2_scales,
        quant_tile_shape=tile_size,
    )
    qx_ref, sx_ref = (
        qresult_ref.data,
        qresult_ref.scale,
    )
    # Check
    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
    torch.testing.assert_close(sx.flatten()[0], sx_ref.flatten()[0], atol=0.0, rtol=0.0)
    if extrema_high:
        expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max
        if pow_2_scales:
            expected_value = math.floor(math.log2(expected_value))
            expected_value = math.pow(2.0, expected_value)
        expected_value = 1 / expected_value
    elif not extrema_high and eps == 0:
        expected_value = 1.0
    else:
        assert not extrema_high
        # eps is small enough to trigger inf in quant_dtype_max / eps
        if pow_2_scales:
            expected_value = math.pow(2.0, -127)
        else:
            expected_value = 1 / torch.finfo(x_dtype).max
    torch.testing.assert_close(
        sx.flatten()[0],
        torch.tensor(expected_value, device=sx.device),
        atol=0.0,
        rtol=0.0,
    )


# FP8 per tensor current scaling
@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLinear(TestFP8RecipeLinearBase):
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=True,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            dgrad_error=1,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase):
    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize(
        "batch_size, hidden_size, out_size",
        [
            (16, 256, 128),
        ],
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
    @pytest.mark.parametrize(
        "recipe1, recipe2",
        [
            (GetRecipes.none, GetRecipes.fp8_blockwise),
        ],
    )
    def test_fp8_current_scaling_with_layernorm_linear_module(
        self,
        recipe1,
        recipe2,
        batch_size,
        hidden_size,
        out_size,
        dtype,
        use_bias=True,
    ):
        fp8_zero_tolerance_tensor_dumps_recipe2 = None
        # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad
        # if we cannot get all four tensors, then still set the tensor dump to None
        tensor_map = self._check_golden_tensor_dumps(
            TENSOR_DUMP_DIR,
            recipe2,
            (batch_size, hidden_size, out_size),
            dtype,
            use_bias,
            "LayerNorm",
        )
        if tensor_map is not None:
            fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map

        self.compare_recipe(
            recipe1,
            recipe2,
            batch_size,
            hidden_size,
            out_size,
            use_bias,
            seed=torch.initial_seed(),
            dtype=dtype,
            y_error=0.5,
            ln_out_error=0.5,
            dgrad_error=1.6,
            wgrad_error=1,
            bgrad_error=0.5,
            recipe1_golden_tensors=None,
            recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2,
        )
tests/pytorch/test_float8_current_scaling_exact.py
View file @ ab3e5a92
...
@@ -82,7 +82,8 @@ class TestFP8RecipeLinearBase:
     @staticmethod
     def _get_mean_abs_relative_error(a, b):
-        return torch.mean(torch.abs((a - b) / b))
+        error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
+        return torch.mean(error)

     @staticmethod
     def _load_golden_tensor_values(a, b):
...
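The rewritten helper guards against division by zero: wherever the reference value is zero, the old formula produced nan/inf and poisoned the mean, while the new one counts a plain mismatch (0 or 1) instead. A standalone sketch of the difference (illustrative values, not from the test suite):

    import torch

    a = torch.tensor([1.0, 0.0, 2.0])
    b = torch.tensor([1.0, 0.0, 0.0])

    old = torch.mean(torch.abs((a - b) / b))  # nan: (0-0)/0 and 2/0 pollute the mean
    new = torch.mean(
        torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)).float()
    )  # explicit cast for clarity; 0/0 counts as a match, 2/0 as one mismatch
    print(old, new)  # nan vs. (0 + 0 + 1) / 3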
@@ -97,9 +98,12 @@ class TestFP8RecipeLinearBase:
         fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
         # Expected tensor names based on the naming template
-        scaling_type = (
-            "ScalingType.PER_TENSOR"  # Assuming the scaling type is PER_TENSOR for this example
-        )
+        if recipe.float8_current_scaling():
+            scaling_type = "ScalingType.PER_TENSOR"
+        elif recipe.float8_block_scaling():
+            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
+        else:
+            scaling_type = "Unknown"
         current_seed = torch.initial_seed()  # Get the current seed
         expected_tensor_names = {
...
@@ -437,9 +441,13 @@ class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase):
         fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False)
         # Expected tensor names based on the naming template
-        scaling_type = (
-            "ScalingType.PER_TENSOR"  # Assuming the scaling type is PER_TENSOR for this example
-        )
+        if recipe.float8_current_scaling():
+            scaling_type = "ScalingType.PER_TENSOR"
+        elif recipe.float8_block_scaling():
+            scaling_type = "ScalingType.VECTOR_TILED_X_AND_G_BLOCK_TILED_W"
+        else:
+            scaling_type = "Unknown"
         current_seed = torch.initial_seed()  # Get the current seed
         expected_tensor_names = {
...
tests/pytorch/test_float8blockwisetensor.py
0 → 100644
View file @ ab3e5a92

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
import io
import math
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
    Float8BlockQuantizer,
    Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
import transformer_engine_torch as tex

# PyTorch tensor dtypes
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]

# Numerical tolerances with FP8 types
_tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.08),
    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),
}


def _to_list(x: Union[Iterable, Any]) -> List:
    """Convert to list if iterable, otherwise put in singleton list"""
    if isinstance(x, Iterable):
        return list(x)
    else:
        return [x]


# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# TODO replace with call to fp8.py when recipe added.
recipe_available = get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8
reason_for_no_recipe = "Quantize kernels require TMA and are only relevant with GEMMs."


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
class TestFloat8BlockwiseTensor:

    @staticmethod
    def setup_class(cls) -> None:
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_constructor(
        self,
        dims: DimsType = 1,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        dtype: torch.dtype = torch.float32,
        is_2D_scaled: bool = True,
    ) -> None:
        """Call constructor and perform sanity checks"""
        dims = _to_list(dims)
        rowwise = True
        columnwise = True
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=rowwise,
            columnwise=columnwise,
            block_scaling_dim=2 if is_2D_scaled else 1,
        )
        scale_dims = quantizer.get_scale_shape(dims, columnwise=False)
        columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True)
        columnwise_dims = quantizer.get_columnwise_shape(dims)
        tensor = Float8BlockwiseQTensor(
            shape=dims,
            dtype=dtype,
            rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
            rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32),
            columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8),
            columnwise_scale_inv=torch.zeros(
                columnwise_scale_dims, device="cuda", dtype=torch.float32
            ),
            fp8_dtype=fp8_dtype,
            is_2D_scaled=is_2D_scaled,
            quantizer=quantizer,
        )
        assert list(tensor.size()) == dims, "Incorrect dims"
        assert tensor.dtype == dtype, "Incorrect nominal dtype"
        assert tensor.is_cuda, "Incorrect device"

    def _test_quantize_dequantize(
        self,
        quantizer: Float8BlockQuantizer,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = (23, 128),
        rtol: float = 0.0,
        atol: float = 0.0,
        dequant_columnwise: bool = False,
        use_cpp_allocation: bool = False,
    ) -> None:
        """Check numerical error when casting to FP8 and back"""
        dims = _to_list(dims)

        # Initialize random data
        # Note: Make sure values are not all close to zero, or else
        # test may pass trivially.
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_ref.view(-1)[0] = 0.75
        x_ref_cuda = x_ref.to("cuda")

        # Cast to FP8 and back
        if not use_cpp_allocation:
            x_fp8 = quantizer.make_empty(shape=dims, device="cuda")
            quantizer.update_quantized(x_ref_cuda, x_fp8)
        else:
            # This codepath allows the CPP binding to allocate the output
            # tensor
            x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None)
        if dequant_columnwise:
            # Strip out rowwise data to verify dequantization of
            # columnwise data.
            x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True)
        x_fp8 = x_fp8.dequantize(dtype=dtype).cpu()

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol)

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_quantize_dequantize_dtypes(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
    ) -> None:
        atol = _tols[fp8_dtype]["atol"]
        rtol = _tols[fp8_dtype]["rtol"]
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=False,
            block_scaling_dim=block_scaling_dim,
        )
        self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("block_scaling_dim", [1])
    def test_quantize_dequantize_columnwise_only(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int
    ) -> None:
        atol = _tols[fp8_dtype]["atol"]
        rtol = _tols[fp8_dtype]["rtol"]
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=False,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        self._test_quantize_dequantize(
            quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol, use_cpp_allocation=True
        )

    @pytest.mark.parametrize(
        "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
    )
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    @pytest.mark.parametrize("dq_columnwise", [True, False])
    def test_quantize_dequantize_dims(
        self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool
    ) -> None:
        atol = _tols[tex.DType.kFloat8E4M3]["atol"]
        rtol = _tols[tex.DType.kFloat8E4M3]["rtol"]
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=dq_columnwise,
            block_scaling_dim=block_scaling_dim,
        )
        self._test_quantize_dequantize(
            quantizer=quantizer,
            dims=dims,
            atol=atol,
            rtol=rtol,
            dequant_columnwise=dq_columnwise,
        )

    @pytest.mark.parametrize(
        "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]]
    )
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dq_columnwise", [True, False])
    def test_quantize_dequantize_dims_cpp_allocate_output(
        self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool
    ) -> None:
        atol = _tols[fp8_dtype]["atol"]
        rtol = _tols[fp8_dtype]["rtol"]
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=dq_columnwise,
            block_scaling_dim=block_scaling_dim,
        )
        self._test_quantize_dequantize(
            quantizer=quantizer,
            dims=dims,
            atol=atol,
            rtol=rtol,
            dequant_columnwise=dq_columnwise,
            use_cpp_allocation=True,
        )

    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_data_accessors(self, dims: DimsType, block_scaling_dim: int) -> None:
        """Test data accessors of Float8BlockwiseQTensor"""
        device = "cuda"
        dtype = torch.bfloat16
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        fp8_dtype = tex.DType.kFloat8E4M3
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        # Create FP8 tensor
        x_fp8 = quantizer.quantize(x_hp)
        x_recovered = x_fp8.data
        torch.testing.assert_close(x_recovered, x_hp, **_tols[fp8_dtype])
        x_fp8.data = y_hp
        y_recovered = x_fp8.data
        torch.testing.assert_close(y_recovered, y_hp, **_tols[fp8_dtype])

    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_serialization(self, dims: DimsType, block_scaling_dim: int) -> None:
        """Test serialization of Float8BlockwiseQTensor"""
        device = "cuda"
        dtype = torch.bfloat16
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E5M2,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        # Create FP8 tensor
        x_fp8 = quantizer.quantize(x_hp)

        # Save tensor
        buffer = io.BytesIO()
        torch.save(x_fp8, buffer)

        # Load tensor
        buffer.seek(0)
        x_fp8_loaded = torch.load(buffer, weights_only=False)

        # Test that loaded tensor matches original
        assert isinstance(x_fp8_loaded, Float8BlockwiseQTensor)
        torch.testing.assert_close(x_fp8_loaded._rowwise_data, x_fp8._rowwise_data)
        torch.testing.assert_close(x_fp8_loaded._columnwise_data, x_fp8._columnwise_data)
        torch.testing.assert_close(x_fp8_loaded._rowwise_scale_inv, x_fp8._rowwise_scale_inv)
        torch.testing.assert_close(x_fp8_loaded._columnwise_scale_inv, x_fp8._columnwise_scale_inv)
        torch.testing.assert_close(x_fp8_loaded.data, x_fp8.data)
        assert x_fp8_loaded._is_2D_scaled == x_fp8._is_2D_scaled
        assert x_fp8_loaded.dtype == x_fp8.dtype
        assert x_fp8_loaded._fp8_dtype == x_fp8._fp8_dtype

        # Test that dequantized values match
        x_fp8_dequant = x_fp8.dequantize()
        x_fp8_loaded_dequant = x_fp8_loaded.dequantize()
        torch.testing.assert_close(x_fp8_loaded_dequant, x_fp8_dequant)

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_inplace_ops(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
    ) -> None:
        """Test in-place operations"""
        device = "cuda"
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )

        # Test in-place add
        x_fp8 = quantizer.quantize(x_hp.clone())
        y_fp8 = quantizer.quantize(y_hp.clone())
        x_fp8.add_(y_fp8)
        torch.testing.assert_close(x_fp8.dequantize(), x_hp + y_hp, **_tols[fp8_dtype])

        # Test in-place subtract
        x_fp8 = quantizer.quantize(x_hp.clone())
        y_fp8 = quantizer.quantize(y_hp.clone())
        x_fp8.sub_(y_fp8)
        torch.testing.assert_close(x_fp8.dequantize(), x_hp - y_hp, **_tols[fp8_dtype])

        # Test in-place multiply
        x_fp8 = quantizer.quantize(x_hp.clone())
        y_fp8 = quantizer.quantize(y_hp.clone())
        x_fp8.mul_(y_fp8)
        torch.testing.assert_close(x_fp8.dequantize(), x_hp * y_hp, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_out_of_place_ops(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
    ) -> None:
        """Test out-of-place operations"""
        device = "cuda"
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        y_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        x_fp8 = quantizer.quantize(x_hp.clone())
        y_fp8 = quantizer.quantize(y_hp.clone())

        # Test exact operations
        torch.testing.assert_close(-x_fp8, -x_hp, **_tols[fp8_dtype])
        torch.testing.assert_close(x_fp8.abs(), x_hp.abs(), **_tols[fp8_dtype])

        # Test elementwise operations
        torch.testing.assert_close(x_fp8 + y_fp8, x_hp + y_hp, **_tols[fp8_dtype])
        torch.testing.assert_close(x_fp8 - y_fp8, x_hp - y_hp, **_tols[fp8_dtype])
        torch.testing.assert_close(x_fp8 * y_fp8, x_hp * y_hp, **_tols[fp8_dtype])
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_hp), **_tols[fp8_dtype])

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_hp - y_hp, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_view_same_shape(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
    ) -> None:
        """Test view operations that preserve tensor shape"""
        device = "cuda"
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
        quantizer.update_quantized(x_hp.clone(), x_fp8)

        # Test view with same shape
        x_view = x_fp8.view(*dims)
        torch.testing.assert_close(x_view.dequantize(), x_hp, **_tols[fp8_dtype])
        assert x_view.shape == x_fp8.shape, "Shape changed after view with same dims"

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_view.dequantize(), -x_hp, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_reshape_same_shape(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
    ) -> None:
        """Test reshape operations that preserve tensor shape"""
        device = "cuda"
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        x_fp8 = quantizer.make_empty(x_hp.shape, dtype=dtype, device=device)
        quantizer.update_quantized(x_hp.clone(), x_fp8)

        # Test reshape with same shape
        x_reshape = x_fp8.reshape(*dims)
        torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
        assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with same dims"

        # Test reshape with -1 canonicalization
        new_dims = [-1, dims[1]]
        x_reshape = x_fp8.reshape(*new_dims)
        torch.testing.assert_close(x_reshape.dequantize(), x_hp, **_tols[fp8_dtype])
        assert x_reshape.shape == x_fp8.shape, "Shape changed after reshape with -1"

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_reshape.dequantize(), -x_hp, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[256, 512], [250, 500]])
    @pytest.mark.parametrize("block_scaling_dim", [1, 2])
    def test_clone_detach(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType, block_scaling_dim: int
    ) -> None:
        """Test clone and detach operations"""
        device = "cuda"
        x_hp = torch.rand(_to_list(dims), dtype=dtype, device=device)
        quantizer = Float8BlockQuantizer(
            fp8_dtype=fp8_dtype,
            rowwise=True,
            columnwise=True,
            block_scaling_dim=block_scaling_dim,
        )
        x_fp8 = quantizer.quantize(x_hp.clone())

        # Test clone
        x_clone = x_fp8.clone()
        torch.testing.assert_close(x_clone.dequantize(), x_hp, **_tols[fp8_dtype])
        assert x_clone.shape == x_fp8.shape, "Shape changed after clone"

        # Test detach
        x_detach = x_fp8.detach()
        torch.testing.assert_close(x_detach.dequantize(), x_hp, **_tols[fp8_dtype])
        assert x_detach.shape == x_fp8.shape, "Shape changed after detach"

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_clone.dequantize(), -x_hp, **_tols[fp8_dtype])
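Outside of pytest, the round-trip these tests exercise looks roughly like the following. This is a minimal sketch reusing the public names from the file above (the tolerances are the E4M3 entries from `_tols`; per the `recipe_available` guard it presumably needs an SM90+ GPU and CUDA 12.8):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

    quantizer = Float8BlockQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=True,
        block_scaling_dim=2,  # 2D block scaling; 1 selects vector-wise tiles
    )
    x = torch.rand(256, 512, dtype=torch.bfloat16, device="cuda")
    x_fp8 = quantizer.quantize(x)   # Float8BlockwiseQTensor
    x_back = x_fp8.dequantize()     # back to bfloat16, within FP8 rounding error
    torch.testing.assert_close(x_back, x, rtol=0.125, atol=0.08)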
tests/pytorch/test_float8tensor.py
View file @ ab3e5a92
...
@@ -4,7 +4,7 @@
 from collections.abc import Iterable
 import io
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, Tuple, Union, Optional

 import pytest
 import torch
...
@@ -158,6 +158,32 @@ class TestFloat8Tensor:
     def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
         self._test_quantize_dequantize(dims=dims)

+    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("noop", [True, False])
+    def test_quantize_dequantize_noop(
+        self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool
+    ) -> None:
+        noop_tensor = torch.zeros(1, dtype=torch.float32, device="cuda")
+        if noop:
+            noop_tensor = torch.ones(1, dtype=torch.float32, device="cuda")
+        dims = 23
+        scale: float = 3.5
+
+        # Initialize random data
+        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1
+
+        # Cast to FP8 and back
+        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
+
+        # With noop set, quantizing a different input should leave x_fp8 unchanged
+        x_ref_noop_test = 2 * x_ref.cuda()
+        x_fp8_orig = x_fp8.clone()
+        x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor)
+        if noop_tensor.item() == 1.0:
+            torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0)
+        else:
+            torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype])
+
     def test_basic_ops(
         self,
         dims: DimsType = 23,
...
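The `noop_flag` exercised above acts as a device-side switch: when the one-element tensor holds 1.0 the quantize update is skipped and the existing FP8 data is kept, and when it holds 0.0 the update proceeds. Because the flag lives on the GPU, the decision can presumably be flipped without host synchronization (useful, for example, under captured CUDA graphs). A minimal sketch, assuming `x_fp8` and `new_data` are set up as in the test:

    import torch

    noop_flag = torch.zeros(1, dtype=torch.float32, device="cuda")
    noop_flag.fill_(1.0)  # subsequent quantize_ calls become no-ops while the flag is 1
    x_fp8.quantize_(new_data, noop_flag=noop_flag)  # leaves x_fp8 unchanged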
tests/pytorch/test_fused_optimizer.py
View file @ ab3e5a92
...
@@ -360,6 +360,20 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )

+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_bf16_exp_avg(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.bfloat16,
+            exp_avg_sq_dtype=torch.float32,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
...
@@ -389,6 +403,20 @@ class TestFusedAdam(TestFusedOptimizer):
             master_atol=2e-3,
         )

+    @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
+    def test_bf16_exp_avg_sq(self):
+        self.gen_precision_aware_test(
+            use_fp8_params=False,
+            param_dtype=torch.bfloat16,
+            use_master_weights=True,
+            master_weight_dtype=torch.float32,
+            grad_dtype=torch.float32,
+            exp_avg_dtype=torch.float32,
+            exp_avg_sq_dtype=torch.bfloat16,
+            master_rtol=2e-3,
+            master_atol=2e-3,
+        )
+
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 is not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg_sq(self):
...
tests/pytorch/test_fused_rope.py
View file @ ab3e5a92
...
@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import (
 )

-def _get_thd_freqs_on_this_cp_rank(
-    cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
-) -> torch.Tensor:
-    if cp_size > 1:
-        cp_seg = x.size(0) // 2
-        full_seqlen = cp_size * x.size(0)
-        return torch.cat(
-            [
-                freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
-                freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
-            ]
-        )
-    else:
-        return freqs[: x.size(0)]
-
-def apply_rotary_pos_emb_thd(
-    t: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    freqs: torch.Tensor,
-    cp_size: int = 1,
-    cp_rank: int = 0,
-) -> torch.Tensor:
-    """A baseline implementation of applying RoPE for `thd` format.
-
-    Args:
-        t (Tensor): Input tensor T is of shape [t, h, d]
-        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
-        with shape [b + 1] and dtype torch.int32.
-        freqs (Tensor): Rotary positional embedding tensor of shape [max_s, 1, 1, d]
-
-    Returns:
-        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
-    """
-    cu_seqlens = cu_seqlens // cp_size
-    seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
-    return torch.cat(
-        [
-            apply_rotary_pos_emb(
-                x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
-            )
-            for x in torch.split(t, seqlens)
-        ]
-    ).squeeze(1)
-
 # Gradient is a broadcasted scalar
 def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
     return output.sum() * 2
...
@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
 @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
 @pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
 @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
 def test_fused_rope(
     dtype: torch.dtype,
     seq_length: int,
@@ -85,6 +41,8 @@ def test_fused_rope(
     transpose: Union[Tuple, None],
     tensor_format: str,
     loss_func: Callable,
+    cp_size: int,
+    interleaved: bool,
 ) -> None:
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
@@ -99,14 +57,22 @@ def test_fused_rope(
         t = t.transpose(*transpose).contiguous().transpose(*transpose)
     t.requires_grad = True

-    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
-    emb = rotary_pos_emb(seq_length)
+    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb = rotary_pos_emb(seq_length * cp_size)
+    assert emb.is_contiguous()

-    # unfused
-    # The fused kernel computes in float32 internally, so we force the unfused func to use float32
-    # for more accurate comparison
-    output_unfused = apply_rotary_pos_emb(
-        t.float(), emb, tensor_format=tensor_format, fused=False
-    ).to(dtype)
-    loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
+    for cp_rank in range(cp_size):
+        # unfused
+        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
+        # for more accurate comparison
+        output_unfused = apply_rotary_pos_emb(
+            t.float(),
+            emb,
+            tensor_format=tensor_format,
+            interleaved=interleaved,
+            fused=False,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+        ).to(dtype)
+        loss_unfused = loss_func(output_unfused)
+        loss_unfused.backward()
...
@@ -118,7 +84,10 @@ def test_fused_rope(
             t,
             emb,
             tensor_format=tensor_format,
+            interleaved=interleaved,
             fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
         )
         loss_fused = loss_func(output_fused)
         loss_fused.backward()
...
@@ -135,7 +104,8 @@ def test_fused_rope(
 @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
 @pytest.mark.parametrize("transpose", [None, (1, 2)])
 @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
-@pytest.mark.parametrize("cp_size", [1, 2, 3])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
 def test_fused_rope_thd(
     dtype: torch.dtype,
     hidden_size: int,
@@ -143,6 +113,7 @@ def test_fused_rope_thd(
     transpose: Union[Tuple, None],
     loss_func: Callable,
     cp_size: int,
+    interleaved: bool,
 ) -> None:
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
@@ -170,15 +141,23 @@ def test_fused_rope_thd(
         t = t.transpose(*transpose).contiguous().transpose(*transpose)
     t.requires_grad = True

-    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
+    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
     emb = rotary_pos_emb(cu_seqlens_padded[-1])
+    assert emb.is_contiguous()

     for cp_rank in range(cp_size):
         # unfused
         # The fused kernel computes in float32 internally, so we force the unfused func to use float32
         # for more accurate comparison
-        output_unfused = apply_rotary_pos_emb_thd(
-            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
+        output_unfused = apply_rotary_pos_emb(
+            t.float(),
+            emb,
+            tensor_format="thd",
+            interleaved=interleaved,
+            fused=False,
+            cu_seqlens=cu_seqlens_padded,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
         ).to(dtype)
         loss_unfused = loss_func(output_unfused)
         loss_unfused.backward()
...
@@ -189,6 +168,7 @@ def test_fused_rope_thd(
         output_fused = apply_rotary_pos_emb(
             t,
             emb,
+            interleaved=interleaved,
             fused=True,
             tensor_format="thd",
             cu_seqlens=cu_seqlens_padded,
...
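The deleted `_get_thd_freqs_on_this_cp_rank` baseline (visible in the removal hunk above) documents the load-balanced context-parallel layout that the library's `apply_rotary_pos_emb` now handles via `cp_size`/`cp_rank`: each rank's chunk is split in half, taking one segment from the front of the full frequency table and its mirror segment from the back. A worked example of the indices, derived directly from that removed code:

    # With cp_size = 2 and x.size(0) = 8: cp_seg = 8 // 2 = 4, full_seqlen = 2 * 8 = 16, so
    #   rank 0 uses freqs[0:4]  + freqs[12:16]
    #   rank 1 uses freqs[4:8]  + freqs[8:12]
    # Front and back segments are paired so every rank sees a comparable mix of
    # short- and long-position rotary frequencies.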
tests/pytorch/test_fusible_ops.py
View file @ ab3e5a92
...
@@ -5,6 +5,7 @@
 from __future__ import annotations
 from collections.abc import Iterable
+import io
 import math
 from typing import Optional
...
@@ -1405,6 +1406,7 @@ class TestBasicOps:
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("cache_quantized_input", (False, True))
     def test_activation(
         self,
         *,
@@ -1413,6 +1415,7 @@ class TestBasicOps:
         dtype: torch.dtype,
         device: torch.device = "cuda",
         quantization: Optional[str],
+        cache_quantized_input: bool,
     ) -> None:
         """Activation functions"""
...
@@ -1424,6 +1427,8 @@ class TestBasicOps:
         # Skip invalid configurations
         quantized_compute = quantization is not None
         maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        if cache_quantized_input:
+            maybe_skip_quantization("fp8", device=device)

         # Random data
         x_ref, x_test = make_reference_and_test_tensors(
...
@@ -1432,15 +1437,17 @@ class TestBasicOps:
             test_device=device,
             test_is_fp8=quantized_compute,
         )
-        if quantized_compute:
-            with torch.no_grad():
-                x_test = x_test.dequantize().requires_grad_()
         dy_ref, dy_test = make_reference_and_test_tensors(
             out_shape,
             test_dtype=dtype,
             test_device=device,
+            test_is_fp8=quantized_compute,
             requires_grad=False,
         )
+        if quantized_compute:
+            with torch.no_grad():
+                x_test = x_test.dequantize().requires_grad_()
+                dy_test = dy_test.dequantize()

         # Plain PyTorch implementation
         y_ref: torch.Tensor
...
@@ -1471,7 +1478,8 @@ class TestBasicOps:
             swiglu=te_ops.SwiGLU,
         )[activation]
         forward = te_ops.Sequential(
-            make_op(),
+            te_ops.Quantize(forward=False, backward=quantized_compute),
+            make_op(cache_quantized_input=cache_quantized_input),
             te_ops.Quantize(forward=quantized_compute, backward=False),
         )
         with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
...
@@ -1480,9 +1488,9 @@ class TestBasicOps:

         # Expected numerical error
         tols = dtype_tols(dtype)
-        if quantized_compute:
+        if quantized_compute or cache_quantized_input:
             tols = dtype_tols(tex.DType.kFloat8E4M3)
-        if activation == "relu":
+        if activation == "relu" and not cache_quantized_input:
             tols = {"atol": 0, "rtol": 0}

         # Check results
...
@@ -1894,3 +1902,118 @@ class TestFusedOps:
         torch.testing.assert_close(y2_test, y2_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
         torch.testing.assert_close(dw_test, w_ref.grad, **tols)
+
+
+class TestCheckpointing:
+    """Tests for checkpointing"""
+
+    @staticmethod
+    def setup_class(cls) -> None:
+        # Configure RNG
+        seed = 1234
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+    @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
+    @pytest.mark.parametrize("quantized_weight", (False, True))
+    def test_linear(
+        self,
+        *,
+        pre_checkpoint_steps: int = 2,
+        post_checkpoint_steps: int = 2,
+        weight_shape: tuple[int, int] = (32, 32),
+        in_shape: Iterable[int] = (32, -1),
+        dtype: torch.dtype = torch.float32,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantized_weight: bool,
+    ) -> None:
+        """Check checkpointing with linear op"""
+
+        # Make input and weight shapes consistent
+        out_features, in_features = weight_shape
+        in_shape = list(in_shape)[:-1] + [in_features]
+        out_shape = in_shape[:-1] + [out_features]
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+        maybe_skip_quantization(quantization, dims=out_shape)
+
+        # Construct model
+        recipe = make_recipe(quantization)
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model_save = te_ops.Sequential(
+                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
+            )
+        optim_save = torch.optim.SGD(model_save.parameters(), lr=0.25)
+
+        # Warmup training steps
+        for _ in range(pre_checkpoint_steps):
+            x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
+            dy = torch.randn(out_shape, dtype=dtype, device=device)
+            optim_save.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_save(x)
+            y.backward(dy)
+            optim_save.step()
+
+        # Save checkpoint
+        byte_stream = io.BytesIO()
+        torch.save(
+            {"model": model_save.state_dict(), "optim": optim_save.state_dict()},
+            byte_stream,
+        )
+        checkpoint_bytes = byte_stream.getvalue()
+        del byte_stream
+
+        # Synthetic data for evaluation
+        xs_save = [
+            torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True)
+            for _ in range(post_checkpoint_steps)
+        ]
+        with torch.no_grad():
+            xs_load = [x.clone().requires_grad_() for x in xs_save]
+        dys = [
+            torch.randn(out_shape, dtype=dtype, device=device)
+            for _ in range(post_checkpoint_steps)
+        ]
+
+        # Training steps with original model
+        ys_save = []
+        for i in range(post_checkpoint_steps):
+            optim_save.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_save(xs_save[i])
+            y.backward(dys[i])
+            optim_save.step()
+            ys_save.append(y)
+
+        # Load checkpoint
+        with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+            model_load = te_ops.Sequential(
+                te_ops.Linear(in_features, out_features, device=device, dtype=dtype)
+            )
+        optim_load = torch.optim.SGD(model_load.parameters(), lr=0.25)
+        state_dict = torch.load(io.BytesIO(checkpoint_bytes), weights_only=False)
+        model_load.load_state_dict(state_dict["model"])
+        optim_load.load_state_dict(state_dict["optim"])
+
+        # Training steps with loaded model
+        ys_load = []
+        for i in range(post_checkpoint_steps):
+            optim_load.zero_grad()
+            with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+                y = model_load(xs_load[i])
+            y.backward(dys[i])
+            optim_load.step()
+            ys_load.append(y)
+
+        # Check that original and loaded model match exactly
+        tols = {"rtol": 0, "atol": 0}
+        for param_load, param_save in zip(model_load.parameters(), model_save.parameters()):
+            torch.testing.assert_close(param_load, param_save, **tols)
+            torch.testing.assert_close(param_load.grad, param_save.grad, **tols)
+        for y_load, y_save in zip(ys_load, ys_save):
+            torch.testing.assert_close(y_load, y_save, **tols)
+        for x_load, x_save in zip(xs_load, xs_save):
+            torch.testing.assert_close(x_load.grad, x_save.grad, **tols)
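Note that both this checkpoint test and the serialization test above load with `weights_only=False`, presumably because the saved state contains custom tensor subclasses (e.g. `Float8BlockwiseQTensor`) that PyTorch's restricted `weights_only` unpickler would reject. A minimal sketch of the pattern, where `model` stands for any module whose state dict may hold such tensors:

    import io
    import torch

    buf = io.BytesIO()
    torch.save({"model": model.state_dict()}, buf)  # may contain TE tensor subclasses
    buf.seek(0)
    state = torch.load(buf, weights_only=False)     # full unpickling needed for custom classes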
tests/pytorch/test_multi_tensor.py
View file @ ab3e5a92
...
@@ -9,9 +9,10 @@ import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
-from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+from references.quantize_scale_calc import scale_from_amax_tensor
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 input_size_pairs = [
     (7777 * 77, 555 * 555),
     (777, 555),
...
@@ -224,17 +225,18 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
 @pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
 @pytest.mark.parametrize("applier", appliers)
 @pytest.mark.parametrize("repeat", [1, 55])
-@pytest.mark.parametrize("max_fp8", [448.0 if not IS_HIP_EXTENSION else 240.0, 57344.0])
+@pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
 @pytest.mark.parametrize("pow_2_scales", [False, True])
 @pytest.mark.parametrize("epsilon", [0.0, 100.0])
 def test_multi_tensor_compute_scale_and_scale_inv(
-    input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
+    input_size_pair, applier, repeat, fp8_dtype, pow_2_scales, epsilon
 ):
     sizea, sizeb = input_size_pair
     device = torch.device("cuda")
     overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
     a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
     b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
+    max_fp8 = torch.finfo(fp8_dtype).max

     amax_list = []
     for i in range(repeat):
...
@@ -253,8 +255,8 @@ def test_multi_tensor_compute_scale_and_scale_inv(
     )

     for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
-        scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
-            amax, max_fp8, epsilon, pow_2_scales
+        scale_ref, scale_inv_ref, _ = scale_from_amax_tensor(
+            torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales
         )
         torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
         torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
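The reference swap above keeps the numerics the same: for a given amax the quantization scale is `fp8_max / amax` and `scale_inv` is its reciprocal, with `eps` presumably acting as a lower bound on amax (the `epsilon=100.0` parametrization suggests as much). A hand-checked example, not taken from the test:

    import torch

    amax = torch.tensor([2.0])
    fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0
    eps = 100.0
    scale = fp8_max / torch.clamp(amax, min=eps)     # eps floors amax: 448 / 100 = 4.48
    scale_inv = 1.0 / scale                          # ~0.2232, the dequantization factor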
tests/pytorch/test_numerics.py
View file @
ab3e5a92
...
@@ -52,6 +52,9 @@ import transformer_engine_torch as tex
...
@@ -52,6 +52,9 @@ import transformer_engine_torch as tex
# Only run FP8 tests on supported devices.
# Only run FP8 tests on supported devices.
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
(
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
)
sm_80plus
=
get_device_compute_capability
()
>=
(
8
,
0
)
sm_80plus
=
get_device_compute_capability
()
>=
(
8
,
0
)
...
@@ -108,6 +111,7 @@ fp8_recipes = [
...
@@ -108,6 +111,7 @@ fp8_recipes = [
recipe
.
MXFP8BlockScaling
(),
recipe
.
MXFP8BlockScaling
(),
recipe
.
DelayedScaling
(),
recipe
.
DelayedScaling
(),
recipe
.
Float8CurrentScaling
(),
recipe
.
Float8CurrentScaling
(),
recipe
.
Float8BlockScaling
(),
]
]
...
@@ -567,6 +571,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
...
@@ -567,6 +571,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
...
@@ -679,6 +685,8 @@ def test_gpt_full_activation_recompute(
...
@@ -679,6 +685,8 @@ def test_gpt_full_activation_recompute(
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
config
=
model_configs
[
model
]
config
=
model_configs
[
model
]
...
@@ -1032,7 +1040,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
...
@@ -1032,7 +1040,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type):
assert_allclose
(
te_output
,
torch_output
,
atol
[
dtype
],
rtol
[
dtype
])
assert_allclose
(
te_output
,
torch_output
,
atol
[
dtype
],
rtol
[
dtype
])
def
_test_granular_accuracy
(
block
,
bs
,
dtype
,
config
):
def
_test_granular_accuracy
(
block
,
bs
,
dtype
,
config
,
delay_wgrad_compute
=
False
):
reset_rng_states
()
reset_rng_states
()
inp_hidden_states
=
torch
.
randn
(
inp_hidden_states
=
torch
.
randn
(
...
@@ -1048,11 +1056,17 @@ def _test_granular_accuracy(block, bs, dtype, config):
...
@@ -1048,11 +1056,17 @@ def _test_granular_accuracy(block, bs, dtype, config):
out
=
out
[
0
]
out
=
out
[
0
]
loss
=
out
.
sum
()
loss
=
out
.
sum
()
loss
.
backward
()
loss
.
backward
()
if
delay_wgrad_compute
:
block
.
backward_dw
()
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
outputs
=
[
out
,
inp_hidden_states
.
grad
]
outputs
=
[
out
,
inp_hidden_states
.
grad
]
for
p
in
block
.
parameters
():
for
p
in
block
.
parameters
():
if
p
.
requires_grad
:
if
p
.
requires_grad
:
if
getattr
(
p
,
"main_grad"
,
None
)
is
not
None
:
outputs
.
append
(
p
.
main_grad
)
assert
p
.
grad
is
None
# grad should be None if fuse_wgrad_accumulation is True
else
:
outputs
.
append
(
p
.
grad
)
outputs
.
append
(
p
.
grad
)
return
outputs
return
outputs
...
@@ -1187,6 +1201,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
...
@@ -1187,6 +1201,54 @@ def test_linear_accuracy(dtype, bs, model, return_bias, bias):
assert_allclose
(
te_output
,
torch_output
,
tolerance
,
rtol
[
dtype
])
assert_allclose
(
te_output
,
torch_output
,
tolerance
,
rtol
[
dtype
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"bs"
,
batch_sizes
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"bias"
,
all_boolean
)
@
pytest
.
mark
.
parametrize
(
"fuse_wgrad_accumulation"
,
all_boolean
)
def
test_linear_accuracy_delay_wgrad_compute
(
dtype
,
bs
,
model
,
bias
,
fuse_wgrad_accumulation
):
config
=
model_configs
[
model
]
te_linear_ref
=
Linear
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
bias
=
bias
,
params_dtype
=
dtype
,
device
=
"cuda"
,
delay_wgrad_compute
=
False
,
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
,
).
eval
()
te_linear
=
Linear
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
bias
=
bias
,
params_dtype
=
dtype
,
device
=
"cuda"
,
delay_wgrad_compute
=
True
,
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
,
).
eval
()
# Share params
with
torch
.
no_grad
():
te_linear_ref
.
weight
=
Parameter
(
te_linear
.
weight
.
clone
())
if
bias
:
te_linear_ref
.
bias
=
Parameter
(
te_linear
.
bias
.
clone
())
if
fuse_wgrad_accumulation
:
weight
=
getattr
(
te_linear
,
f
"weight"
)
weight
.
main_grad
=
torch
.
rand_like
(
weight
,
dtype
=
torch
.
float32
)
te_linear_ref
.
weight
.
main_grad
=
weight
.
main_grad
.
clone
()
te_outputs
=
_test_granular_accuracy
(
te_linear
,
bs
,
dtype
,
config
,
delay_wgrad_compute
=
True
)
te_outputs_ref
=
_test_granular_accuracy
(
te_linear_ref
,
bs
,
dtype
,
config
,
delay_wgrad_compute
=
False
)
# Shoule be bit-wise match
for
i
,
(
o
,
o_ref
)
in
enumerate
(
zip
(
te_outputs
,
te_outputs_ref
)):
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
0
,
atol
=
0
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"bs"
,
batch_sizes
)
@
pytest
.
mark
.
parametrize
(
"bs"
,
batch_sizes
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"126m"
])
...
@@ -1372,6 +1434,67 @@ def test_layernorm_linear_accuracy(
...
@@ -1372,6 +1434,67 @@ def test_layernorm_linear_accuracy(
assert_allclose
(
te_output
,
torch_output
,
atol
[
dtype
],
rtol
[
dtype
])
assert_allclose
(
te_output
,
torch_output
,
atol
[
dtype
],
rtol
[
dtype
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"bs"
,
batch_sizes
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"small"
])
@
pytest
.
mark
.
parametrize
(
"normalization"
,
all_normalizations
)
@
pytest
.
mark
.
parametrize
(
"zero_centered_gamma"
,
all_boolean
)
@
pytest
.
mark
.
parametrize
(
"bias"
,
all_boolean
)
@
pytest
.
mark
.
parametrize
(
"fuse_wgrad_accumulation"
,
all_boolean
)
def
test_layernorm_linear_accuracy_delay_wgrad_compute
(
dtype
,
bs
,
model
,
normalization
,
zero_centered_gamma
,
bias
,
fuse_wgrad_accumulation
):
config
=
model_configs
[
model
]
ln_linear_ref
=
LayerNormLinear
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
config
.
eps
,
bias
=
bias
,
normalization
=
normalization
,
params_dtype
=
dtype
,
zero_centered_gamma
=
zero_centered_gamma
,
device
=
"cuda"
,
delay_wgrad_compute
=
False
,
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
,
).
eval
()
ln_linear
=
LayerNormLinear
(
config
.
hidden_size
,
4
*
config
.
hidden_size
,
config
.
eps
,
bias
=
bias
,
normalization
=
normalization
,
params_dtype
=
dtype
,
zero_centered_gamma
=
zero_centered_gamma
,
device
=
"cuda"
,
delay_wgrad_compute
=
True
,
fuse_wgrad_accumulation
=
fuse_wgrad_accumulation
,
).
eval
()
# Share params
with
torch
.
no_grad
():
ln_linear_ref
.
layer_norm_weight
=
Parameter
(
ln_linear
.
layer_norm_weight
.
clone
())
if
normalization
!=
"RMSNorm"
:
ln_linear_ref
.
layer_norm_bias
=
Parameter
(
ln_linear
.
layer_norm_bias
.
clone
())
ln_linear_ref
.
weight
=
Parameter
(
ln_linear
.
weight
.
clone
())
if
bias
:
ln_linear_ref
.
bias
=
Parameter
(
ln_linear
.
bias
.
clone
())
if
fuse_wgrad_accumulation
:
weight
=
getattr
(
ln_linear
,
f
"weight"
)
weight
.
main_grad
=
torch
.
rand_like
(
weight
,
dtype
=
torch
.
float32
)
ln_linear_ref
.
weight
.
main_grad
=
weight
.
main_grad
.
clone
()
te_outputs
=
_test_granular_accuracy
(
ln_linear
,
bs
,
dtype
,
config
,
delay_wgrad_compute
=
True
)
te_outputs_ref
=
_test_granular_accuracy
(
ln_linear_ref
,
bs
,
dtype
,
config
,
delay_wgrad_compute
=
False
)
# Shoule be bit-wise match
for
i
,
(
o
,
o_ref
)
in
enumerate
(
zip
(
te_outputs
,
te_outputs_ref
)):
torch
.
testing
.
assert_close
(
o
,
o_ref
,
rtol
=
0
,
atol
=
0
)
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
...
@@ -1448,8 +1571,78 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret
    assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype])


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("bias", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
def test_layernorm_mlp_accuracy_delay_wgrad_compute(
    dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation
):
    config = model_configs[model]

    ln_mlp = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=True,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()
    ln_mlp_ref = LayerNormMLP(
        hidden_size=config.hidden_size,
        ffn_hidden_size=4 * config.hidden_size,
        eps=config.eps,
        bias=bias,
        normalization=normalization,
        params_dtype=dtype,
        device="cuda",
        delay_wgrad_compute=False,
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
    ).eval()

    # Share params
    with torch.no_grad():
        ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone())
        if normalization != "RMSNorm":
            ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone())
        ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone())
        ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone())
        if bias:
            ln_mlp_ref.fc1_bias = Parameter(ln_mlp.fc1_bias.clone())
            ln_mlp_ref.fc2_bias = Parameter(ln_mlp.fc2_bias.clone())
        if fuse_wgrad_accumulation:
            ln_mlp.fc1_weight.main_grad = torch.rand_like(ln_mlp.fc1_weight, dtype=torch.float32)
            ln_mlp_ref.fc1_weight.main_grad = ln_mlp.fc1_weight.main_grad.clone()
            ln_mlp.fc2_weight.main_grad = torch.rand_like(ln_mlp.fc2_weight, dtype=torch.float32)
            ln_mlp_ref.fc2_weight.main_grad = ln_mlp.fc2_weight.main_grad.clone()

    te_outputs = _test_granular_accuracy(ln_mlp, bs, dtype, config, delay_wgrad_compute=True)
    te_outputs_ref = _test_granular_accuracy(ln_mlp_ref, bs, dtype, config, delay_wgrad_compute=False)

    # Should be bit-wise match
    for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)):
        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
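For readers unfamiliar with the knob being tested, here is a minimal usage sketch of the delayed-wgrad pattern. It assumes LayerNormMLP exposes the same backward_dw() hook that the grouped-linear helper below calls explicitly; sizes are illustrative only:

import torch
from transformer_engine.pytorch import LayerNormMLP

mlp = LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    params_dtype=torch.bfloat16,
    device="cuda",
    delay_wgrad_compute=True,  # defer the weight-gradient GEMMs
)
x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
mlp(x).sum().backward()  # backward computes dgrad only; wgrad is deferred
mlp.backward_dw()        # assumption: explicitly materializes the weight gradients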
def _test_grouped_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute=False,
):
    reset_rng_states()
    if fp8:
...
@@ -1466,7 +1659,6 @@ def _test_grouped_linear_accuracy(
    if num_gemms > 1:
        split_size = 1
        if fp8:
+            if recipe.delayed():
                split_size = 16
            if recipe.mxfp8():
                split_size = 128
...
@@ -1492,6 +1684,12 @@ def _test_grouped_linear_accuracy(
    )
    loss = out.sum()
    loss.backward()
+    if delay_wgrad_compute:
+        if isinstance(block, GroupedLinear):
+            block.backward_dw()
+        else:
+            for i in range(num_gemms):
+                block[i].backward_dw()
    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
...
@@ -1505,33 +1703,34 @@ def _test_grouped_linear_accuracy(
    return outputs
-@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
-@pytest.mark.parametrize("fp8", all_boolean)
-@pytest.mark.parametrize("recipe", fp8_recipes)
+@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("bias", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy(
    dtype,
    num_gemms,
    bs,
    model,
-    fp8,
    recipe,
    fp8_model_params,
    fuse_wgrad_accumulation,
+    bias,
+    delay_wgrad_compute,
    parallel_mode=None,
):
+    fp8 = recipe is not None
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
-    if recipe.mxfp8() and not mxfp8_available:
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.mxfp8():
-        # TODO(ksivamani): debug mismatches
-        pytest.skip("MXFP8 unsupported for grouped linear.")
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
...
@@ -1542,18 +1741,19 @@ def test_grouped_linear_accuracy(
        num_gemms,
        config.hidden_size,
        4 * config.hidden_size,
-        bias=True,
+        bias=bias,
        params_dtype=dtype,
        parallel_mode=parallel_mode,
        device="cuda",
        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+        delay_wgrad_compute=delay_wgrad_compute,
    ).eval()
    sequential_linear = torch.nn.ModuleList(
        [
            Linear(
                config.hidden_size,
                4 * config.hidden_size,
-                bias=True,
+                bias=bias,
                params_dtype=dtype,
                parallel_mode=parallel_mode,
                device="cuda",
...
@@ -1567,6 +1767,7 @@ def test_grouped_linear_accuracy(
    with torch.no_grad():
        for i in range(num_gemms):
            sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+            if bias:
                sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
            if fuse_wgrad_accumulation:
                weight_i = getattr(grouped_linear, f"weight{i}")
...
@@ -1578,10 +1779,26 @@ def test_grouped_linear_accuracy(
    os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
    outputs_ref = _test_grouped_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute,
    )
    outputs = _test_grouped_linear_accuracy(
-        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute,
    )
    # Should be bit-wise match
...
@@ -1589,24 +1806,7 @@ def test_grouped_linear_accuracy(
    torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


-@pytest.mark.parametrize("parallel_mode", ["column", "row"])
-@pytest.mark.parametrize("recipe", fp8_recipes + [None])
-def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe):
-    """Split the tests to save CI time"""
-    test_grouped_linear_accuracy(
-        dtype=torch.float32,
-        num_gemms=6,
-        bs=2,
-        model="126m",
-        fp8=True,
-        recipe=recipe,
-        fp8_model_params=True,
-        parallel_mode=parallel_mode,
-        fuse_wgrad_accumulation=True,
-    )
-
-
@pytest.mark.parametrize("recipe", fp8_recipes)
def test_grouped_linear_accuracy_single_gemm(recipe):
    """Split the tests to save CI time"""
    test_grouped_linear_accuracy(
...
@@ -1614,19 +1814,23 @@ def test_grouped_linear_accuracy_single_gemm(recipe):
        num_gemms=1,
        bs=2,
        model="126m",
-        fp8=True,
        recipe=recipe,
        fp8_model_params=True,
        fuse_wgrad_accumulation=True,
+        bias=True,
+        delay_wgrad_compute=False,
    )
def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False):

    def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
-        """Padding tensor shapes to multiples of 16."""
+        align_size = 16
+        if recipe.mxfp8():
+            align_size = 32
        padded_tokens_per_expert = [
-            (num_tokens + 15) // 16 * 16 for num_tokens in tokens_per_expert
+            (num_tokens + align_size - 1) // align_size * align_size
+            for num_tokens in tokens_per_expert
        ]
        hidden_states = torch.split(hidden_states, tokens_per_expert)
        padded_hidden_states = []
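The comprehension above uses the standard round-up-to-a-multiple idiom; a self-contained sketch (values are illustrative, not taken from the test):

def pad_to_multiple(num_tokens: int, align_size: int) -> int:
    # Round num_tokens up to the next multiple of align_size.
    return (num_tokens + align_size - 1) // align_size * align_size

assert pad_to_multiple(100, 16) == 112  # default FP8 alignment
assert pad_to_multiple(100, 32) == 128  # MXFP8 alignment per this change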
...
@@ -1727,10 +1931,8 @@ def test_padding_grouped_linear_accuracy(
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
-    if fp8 and recipe.mxfp8():
-        # TODO(ksivamani): debug mismatches
-        pytest.skip("MXFP8 unsupported for grouped linear.")
-    if fp8 and recipe.float8_current_scaling():
-        pytest.skip("Float8 Current Scaling unsupported for grouped linear.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)

    config = model_configs[model]
    if config.seq_len % 16 != 0 and fp8:
...
@@ -1941,6 +2143,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    config = model_configs[model]
...
tests/pytorch/test_permutation.py  View file @ ab3e5a92
...
@@ -8,6 +8,7 @@ import torch
import pytest
from typing import Dict, List
+from transformer_engine.common import recipe
from transformer_engine.pytorch import (
    moe_permute as te_permute,
    moe_permute_with_probs as te_permute_with_probs,
...
@@ -17,9 +18,14 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch.utils import is_bf16_compatible
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
+import copy

seed = 1234
torch.manual_seed(seed)
...
@@ -234,7 +240,6 @@ def _test_permutation_index_map(
        f" token: {num_tokens} hidden_size: {hidden_size} expert: {num_expert} topK: {topK} {te_dtype}"
    )

-    fp8 = False
    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
...
@@ -242,45 +247,9 @@ def _test_permutation_index_map(
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
-    elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
-        dtype = torch.uint8
-        fp8 = True
    else:
        pytest.skip("Invalid dtype.")

-    if fp8:
-        permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        permute_bwd_input = torch.rand(size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        unpermute_bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        _permute_fwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _permute_bwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _unpermute_bwd_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
-        permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
-        unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input)
-        pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
-        pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
-        pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
-    else:
        pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
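The branch removed above exercised the Float8Quantizer round trip; a condensed sketch of that usage, kept for reference since every call appears verbatim in the deleted code (scale and amax of 1.0 make the quantization an identity scaling):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

quantizer = Float8Quantizer(
    scale=torch.full([1], 1.0).cuda().squeeze(),
    amax=torch.full([1], 1.0).cuda(),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = torch.rand(16, 32, device="cuda")
x_fp8 = quantizer(x)                           # quantize to FP8
x_ref = x_fp8.dequantize(dtype=torch.float16)  # high-precision reference copy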
...
@@ -323,9 +292,9 @@ def _test_permutation_index_map(
    # TE Permutation
    #
    ###################################################################################################################################
-    te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
+    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)
-    te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
+    te_permute_bwd_input = pytorch_permute_bwd_input.detach()
    te_permute_output, row_id_map = te_permute(
        te_permute_fwd_input, indices, num_out_tokens, map_type="index"
...
@@ -338,7 +307,7 @@ def _test_permutation_index_map(
    te_probs.requires_grad_(True)
    te_unpermute_fwd_input = te_permute_output.detach()
    te_unpermute_fwd_input.requires_grad_(True)
-    te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
+    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
    te_unpermute_output = te_unpermute(
        te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
...
@@ -352,12 +321,6 @@ def _test_permutation_index_map(
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)
-    if fp8:
-        te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
-        te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
-        te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
-        te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
-    else:
        te_permute_output_ = te_permute_output.float()
        te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
        te_unpermute_output_ = te_unpermute_output.float()
...
@@ -487,7 +450,6 @@ def _test_permutation_mask_map(
        f" token: {num_tokens} hidden_size: {hidden_size} expert: {num_expert} topK: {topK} {te_dtype}"
    )

-    fp8 = False
    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
...
@@ -495,46 +457,9 @@ def _test_permutation_mask_map(
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
-    elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
-        dtype = torch.uint8
-        fp8 = True
    else:
        pytest.skip("Invalid dtype.")

-    if fp8:
-        permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        permute_bwd_input = torch.rand(size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        unpermute_bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        _permute_fwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _permute_bwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _unpermute_bwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input)
-        permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input)
-        unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input)
-        pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
-        pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16)
-        pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
-    else:
        pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...
@@ -553,9 +478,6 @@ def _test_permutation_mask_map(
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
-    if fp8:
-        probs = probs.to(torch.float16)
-    else:
        probs = probs.to(dtype)
    probs.requires_grad_(True)
...
@@ -582,9 +504,9 @@ def _test_permutation_mask_map(
    # TE Permutation
    #
    ###################################################################################################################################
-    te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
+    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)
-    te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach()
+    te_permute_bwd_input = pytorch_permute_bwd_input.detach()
    te_permute_output, row_id_map = te_permute(
        te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
...
@@ -597,7 +519,7 @@ def _test_permutation_mask_map(
    te_probs.requires_grad_(True)
    te_unpermute_fwd_input = te_permute_output.detach()
    te_unpermute_fwd_input.requires_grad_(True)
-    te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
+    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
    te_unpermute_output = te_unpermute(
        te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
...
@@ -611,12 +533,6 @@ def _test_permutation_mask_map(
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)
-    if fp8:
-        te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32)
-        te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32)
-        te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
-        te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32)
-    else:
        te_permute_output_ = te_permute_output.float()
        te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
        te_unpermute_output_ = te_unpermute_output.float()
...
@@ -730,6 +646,118 @@ def _test_permutation_mask_map(
    print(f"unpermute \t bwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms")


def _test_permutation_mask_map_fp8(
    te_dtype,
    num_tokens,
    num_expert,
    hidden_size,
    topK,
    num_out_tokens,
    recipe,
):
    if topK > num_expert:
        pytest.skip("topK should be smaller than the number of experts.")
    if num_out_tokens is None:
        num_out_tokens = num_tokens * topK

    if recipe.delayed():
        quantizer = Float8Quantizer(
            scale=torch.full([1], 1.0).cuda().squeeze(),
            amax=torch.full([1], 1.0).cuda(),
            fp8_dtype=te_dtype,
        )
    elif recipe.float8_current_scaling():
        quantizer = Float8CurrentScalingQuantizer(
            fp8_dtype=te_dtype,
            device=torch.device("cuda"),
            columnwise=False,
        )
    elif recipe.float8_block_scaling():
        quantizer = Float8BlockQuantizer(
            fp8_dtype=te_dtype,
            rowwise=True,
            columnwise=False,
            amax_epsilon=0.0,
            force_pow_2_scales=True,  # Fp8 sub-channel a2a requires e8 scales
            block_scaling_dim=1,  # 1x128 scaling
        )
    elif recipe.mxfp8():
        quantizer = MXFP8Quantizer(
            fp8_dtype=te_dtype,
            rowwise=True,
            columnwise=False,
        )
    else:
        raise ValueError("Unsupported FP8 recipe")

    permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
    # Make an empty fp8 tensor
    permute_fwd_input_fp8 = quantizer.make_empty(
        permute_fwd_input.shape,
        dtype=permute_fwd_input.dtype,
        device=permute_fwd_input.device,
    )
    # Quantize the tensor
    quantizer.update_quantized(permute_fwd_input, permute_fwd_input_fp8)

    if recipe.float8_block_scaling():
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
        pytorch_permute_fwd_scale_input = copy.deepcopy(
            permute_fwd_input_fp8._rowwise_scale_inv.T.contiguous()
        )
    elif recipe.mxfp8():
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._rowwise_data)
        pytorch_permute_fwd_scale_input = copy.deepcopy(
            permute_fwd_input_fp8._rowwise_scale_inv.contiguous()
        )
    else:
        pytorch_permute_fwd_input = copy.deepcopy(permute_fwd_input_fp8._data)
        pytorch_permute_fwd_scale_input = None

    _tmp_tensor = torch.zeros((num_tokens * num_expert,))
    _tmp_tensor[: int(num_out_tokens)] = 1.0
    _tmp_idx = torch.randperm(num_tokens * num_expert)
    routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda()

    # PyTorch permutation
    pytorch_permute_output, _ = pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map)
    if pytorch_permute_fwd_scale_input is not None:
        pytorch_permute_scale_output, _ = pytorch_permute_mask_map(
            pytorch_permute_fwd_scale_input, routing_map
        )

    # TE permutation
    permute_output, _ = te_permute(
        permute_fwd_input_fp8, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
    )
    if recipe.float8_block_scaling():
        te_permute_output = permute_output._rowwise_data
        te_permute_scale_output = permute_output._rowwise_scale_inv.T.contiguous()
    elif recipe.mxfp8():
        te_permute_output = permute_output._rowwise_data
        te_permute_scale_output = permute_output._rowwise_scale_inv.contiguous()
    else:
        te_permute_output = permute_output._data
        te_permute_scale_output = None

    # Check the permute output
    torch.testing.assert_close(pytorch_permute_output, te_permute_output, atol=0, rtol=0)
    if recipe.float8_block_scaling() or recipe.mxfp8():
        torch.testing.assert_close(
            pytorch_permute_scale_output, te_permute_scale_output, atol=0, rtol=0
        )
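Two details of the new helper are worth noting. First, scale handling is recipe-specific: for 1x128 block scaling the test compares _rowwise_scale_inv transposed to match the row-major data layout, while MXFP8 scales are compared directly; the in-code comment attributes force_pow_2_scales=True to the FP8 sub-channel all-to-all requiring E8 scales. Second, the comparisons use atol=0/rtol=0: a permutation only moves rows of the raw FP8 payload and their scales, so it must not change any bytes.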
def _test_moe_chunk_sort(
    te_dtype,
    num_tokens,
...
@@ -743,7 +771,6 @@ def _test_moe_chunk_sort(
        f" token: {num_tokens} hidden_size: {hidden_size} num_expert: {num_expert} tp_size: {tp_size} {te_dtype}"
    )

-    fp8 = False
    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
...
@@ -751,32 +778,9 @@ def _test_moe_chunk_sort(
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
-    elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
-        dtype = torch.uint8
-        fp8 = True
    else:
        pytest.skip("Invalid dtype.")

-    if fp8:
-        fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        _fwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _bwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        fwd_input = _fwd_input_quantizer.quantize(fwd_input)
-        bwd_input = _bwd_input_quantizer.quantize(bwd_input)
-        pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16)
-        pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16)
-    else:
        pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...
@@ -806,9 +810,9 @@ def _test_moe_chunk_sort(
    # TE Permutation
    #
    ###################################################################################################################################
-    te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach()
+    te_fwd_input = pytorch_fwd_input.detach()
    te_fwd_input.requires_grad_(True)
-    te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach()
+    te_bwd_input = pytorch_bwd_input.detach()
    te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
    te_output.backward(te_bwd_input, retain_graph=True)
...
@@ -820,10 +824,6 @@ def _test_moe_chunk_sort(
    ###################################################################################################################################
    tols = dtype_tols(te_dtype)
-    if fp8:
-        te_output_ = te_output.dequantize(dtype=torch.float32)
-        te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32)
-    else:
        te_output_ = te_output.float()
        te_fwd_input_grad = te_fwd_input.grad.float()
...
@@ -899,7 +899,6 @@ def _test_permutation_mask_map_alongside_probs(
        f" token: {num_tokens} hidden_size: {hidden_size} expert: {num_expert} topK: {topK} {te_dtype}"
    )

-    fp8 = False
    # Convert TE dtypes to PyTorch dtypes
    if te_dtype == tex.DType.kFloat32:
        dtype = torch.float32
...
@@ -907,36 +906,9 @@ def _test_permutation_mask_map_alongside_probs(
        dtype = torch.float16
    elif te_dtype == tex.DType.kBFloat16:
        dtype = torch.bfloat16
-    elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3):
-        dtype = torch.uint8
-        fp8 = True
    else:
        pytest.skip("Invalid dtype.")

-    if fp8:
-        permute_fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        unpermute_bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda")
-        _permute_fwd_input_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        _unpermute_bwd_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input)
-        unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input)
-        pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16)
-        pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16)
-    else:
        pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
        pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda()
...
@@ -952,9 +924,6 @@ def _test_permutation_mask_map_alongside_probs(
    probs = torch.rand(num_tokens, num_expert).cuda() * routing_map
    row_sums = probs.sum(dim=1, keepdim=True)
    probs = probs / row_sums
-    if fp8:
-        probs = probs.to(torch.float16)
-    else:
        probs = probs.to(dtype)
    probs.requires_grad_(True)
...
@@ -1006,13 +975,12 @@ def _test_permutation_mask_map_alongside_probs(
    # TE Permutation
    #
    ###################################################################################################################################
-    te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach()
+    te_permute_fwd_input = pytorch_permute_fwd_input.detach()
    te_permute_fwd_input.requires_grad_(True)
-    te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach()
+    te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()
    te_probs = probs.detach()
    te_probs.requires_grad_(True)
+    print(te_probs.shape)
    te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
        te_permute_fwd_input,
...
@@ -1020,25 +988,12 @@ def _test_permutation_mask_map_alongside_probs(
        routing_map,
        num_out_tokens=num_out_tokens,
    )
+    print(te_permuted_probs.shape)
    te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
        te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
    )
-    if fp8:
-        _permute_output_quantizer = Float8Quantizer(
-            scale=torch.full([1], 1.0).cuda().squeeze(),
-            amax=torch.full([1], 1.0).cuda(),
-            fp8_dtype=te_dtype,
-        )
-        te_permute_output = te_permute_output.dequantize(dtype=torch.float32)
-        te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
-        te_permute_output = _permute_output_quantizer.quantize(te_permute_output)
-    else:
        te_permute_output_dtype = te_permute_output.dtype
+        print(te_permute_output.shape)
+        print(te_permuted_probs.shape)
        te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
        te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
...
@@ -1058,11 +1013,6 @@ def _test_permutation_mask_map_alongside_probs(
    tols = dtype_tols(te_dtype)
-    if fp8:
-        # backward of dequantize is in high precision
-        te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
-        te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32)
-    else:
        te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
        te_unpermute_output_ = te_unpermute_output.float()
...
@@ -1228,6 +1178,16 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
# Only run FP8 tests on H100.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)
+
+fp8_recipes = [
+    recipe.MXFP8BlockScaling(),
+    recipe.DelayedScaling(),
+    recipe.Float8CurrentScaling(),
+    recipe.Float8BlockScaling(),
+]
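These four recipes drive the quantizer dispatch in _test_permutation_mask_map_fp8 above (DelayedScaling maps to Float8Quantizer, Float8CurrentScaling to Float8CurrentScalingQuantizer, Float8BlockScaling to Float8BlockQuantizer, MXFP8BlockScaling to MXFP8Quantizer), while the two availability probes added here gate the corresponding pytest skips in the parametrized tests below.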
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
...
@@ -1237,36 +1197,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
-def test_permutation_index_map_fp8(
-    te_dtype,
-    num_tokens,
-    num_expert,
-    hidden_size,
-    topK,
-    num_out_tokens,
-):
-    with_probs = True
-    BENCHMARK = False
-
-    _test_permutation_index_map(
-        te_dtype=te_dtype,
-        num_tokens=num_tokens,
-        num_expert=num_expert,
-        hidden_size=hidden_size,
-        topK=topK,
-        num_out_tokens=num_out_tokens,
-        with_probs=with_probs,
-        BENCHMARK=BENCHMARK,
-    )
-
-
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
-@pytest.mark.parametrize("num_tokens", [2048])
-@pytest.mark.parametrize("num_expert", [8, 16])
-@pytest.mark.parametrize("hidden_size", [4096])
-@pytest.mark.parametrize("topK", [1, 2, 5])
-@pytest.mark.parametrize("num_out_tokens", [None, 2039])
+@pytest.mark.parametrize("recipe", fp8_recipes)
def test_permutation_mask_map_fp8(
    te_dtype,
    num_tokens,
...
@@ -1274,47 +1205,21 @@ def test_permutation_mask_map_fp8(
    hidden_size,
    topK,
    num_out_tokens,
+    recipe,
):
-    with_probs = True
-    BENCHMARK = False
-
-    _test_permutation_mask_map(
-        te_dtype=te_dtype,
-        num_tokens=num_tokens,
-        num_expert=num_expert,
-        hidden_size=hidden_size,
-        topK=topK,
-        num_out_tokens=num_out_tokens,
-        with_probs=with_probs,
-        BENCHMARK=BENCHMARK,
-    )
-
-
-@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
-@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
-@pytest.mark.parametrize("num_tokens", [2048])
-@pytest.mark.parametrize("num_expert", [8, 16])
-@pytest.mark.parametrize("hidden_size", [4096])
-@pytest.mark.parametrize("topK", [1, 2, 5])
-@pytest.mark.parametrize("num_out_tokens", [None, 2039])
-@pytest.mark.parametrize("tp_size", [1, 2, 8])
-def test_permutation_mask_map_alongside_probs_fp8(
-    te_dtype,
-    num_tokens,
-    num_expert,
-    hidden_size,
-    topK,
-    num_out_tokens,
-    tp_size,
-):
-    _test_permutation_mask_map_alongside_probs(
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+
+    _test_permutation_mask_map_fp8(
        te_dtype=te_dtype,
        num_tokens=num_tokens,
        num_expert=num_expert,
        hidden_size=hidden_size,
        topK=topK,
        num_out_tokens=num_out_tokens,
-        tp_size=tp_size,
+        recipe=recipe,
    )
...
@@ -1415,11 +1320,9 @@ def test_permutation_single_case():
    # te_dtype = tex.DType.kFloat32
    # te_dtype = tex.DType.kFloat16
-    te_dtype = tex.DType.kBFloat16
+    # te_dtype = tex.DType.kBFloat16
+    te_dtype = tex.DType.kFloat8E5M2
+    # te_dtype = tex.DType.kFloat8E4M3
-    num_tokens = 10
+    num_tokens = 12
    num_expert = 4
    hidden_size = 16
    topK = 2
...
tests/pytorch/test_sanity.py  View file @ ab3e5a92
...
@@ -43,10 +43,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.utils import replace_raw_data
+from transformer_engine.pytorch.distributed import checkpoint
from test_numerics import reset_rng_states, dtype_tols

# Only run FP8 tests on supported devices.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
...
@@ -113,6 +117,7 @@ fp8_recipes = [
    None,  # Test non-FP8
    recipe.MXFP8BlockScaling(),  # Test default
    recipe.Float8CurrentScaling(),  # Test default
+    recipe.Float8BlockScaling(),  # Test default
    recipe.DelayedScaling(),  # Test default
    recipe.DelayedScaling(  # Test most_recent algo
        amax_history_len=16,
...
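As a reminder of how a recipe from this list is consumed, a hedged sketch using the standard fp8_autocast entry point (the sanity-test helpers wrap this; exact shapes and arguments here are illustrative, not taken from the tests):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.Float8BlockScaling()
linear = te.Linear(128, 128, params_dtype=torch.bfloat16, device="cuda")
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)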
@@ -446,6 +451,8 @@ def test_sanity_layernorm_linear(
    if fp8_recipe is not None:
        if not fp8_available:
            pytest.skip(reason_for_no_fp8)
+        if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
+            pytest.skip(reason_for_no_fp8_block_scaling)
        if fp8_recipe.mxfp8() and not mxfp8_available:
            pytest.skip(reason_for_no_mxfp8)
        if not config.is_fp8_supported():
@@ -477,6 +484,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
...
@@ -477,6 +484,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -509,6 +518,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
...
@@ -509,6 +518,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -550,10 +561,10 @@ def test_sanity_grouped_linear(
...
@@ -550,10 +561,10 @@ def test_sanity_grouped_linear(
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
mxfp8
():
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
"Grouped linear does not support MXFP8"
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_
current
_scaling
():
if
fp8_recipe
.
float8_
block
_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
"Grouped linear does not support FP8 current
scaling
"
)
pytest
.
skip
(
reason_for_no_fp8_block_
scaling
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
...
@@ -597,6 +608,8 @@ def test_sanity_layernorm_mlp(
...
@@ -597,6 +608,8 @@ def test_sanity_layernorm_mlp(
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -647,6 +660,8 @@ def test_sanity_gpt(
...
@@ -647,6 +660,8 @@ def test_sanity_gpt(
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -714,6 +729,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
...
@@ -714,6 +729,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -773,6 +790,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
...
@@ -773,6 +790,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -830,6 +849,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
...
@@ -830,6 +849,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad):
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -865,6 +886,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
...
@@ -865,6 +886,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad):
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -903,6 +926,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
...
@@ -903,6 +926,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -944,6 +969,8 @@ def test_sanity_gradient_accumulation_fusion(
...
@@ -944,6 +969,8 @@ def test_sanity_gradient_accumulation_fusion(
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
...
@@ -990,8 +1017,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
...
@@ -990,8 +1017,12 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
if
fp8_recipe
is
not
None
:
if
fp8_recipe
is
not
None
:
if
not
fp8_available
:
if
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
fp8_recipe
.
float8_block_scaling
()
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
if
fp8_recipe
.
mxfp8
()
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
fp8_recipe
.
float8_block_scaling
():
pytest
.
skip
(
"cuda graph not supported for float8_block_scaling recipe"
)
if
not
config
.
is_fp8_supported
():
if
not
config
.
is_fp8_supported
():
pytest
.
skip
(
"Model config does not support FP8"
)
pytest
.
skip
(
"Model config does not support FP8"
)
...
@@ -1266,3 +1297,31 @@ def test_fp8_model_init_high_precision_init_val():
    assert not hasattr(
        weight, "._high_precision_init_val"
    ), "clear_high_precision_init_val() not work"


def test_sanity_checkpointing_on_callables():
    """Test that TE checkpointing works correctly on callable modules."""

    # torch.autograd.Function
    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            return inp

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output

    module = MyFunction.apply
    inp = torch.randn(10, 10, device="cuda", requires_grad=True)

    out_checkpoint = checkpoint(module, inp)
    out_checkpoint.sum().backward()
    grad_checkpoint = inp.grad

    out_standard = module(inp)
    out_standard.sum().backward()
    grad_standard = inp.grad

    # Assert that gradients are the same
    torch.testing.assert_close(grad_checkpoint, grad_standard)
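A minimal usage sketch of the checkpoint API the new test exercises; any callable taking tensors works, and the backward pass recomputes the forward instead of storing activations (values and the function are illustrative, not from the test):

import torch
from transformer_engine.pytorch.distributed import checkpoint

def fn(t):
    return t * 2.0

inp = torch.randn(4, 4, device="cuda", requires_grad=True)
out = checkpoint(fn, inp)
out.sum().backward()  # backward triggers recomputation of fn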
transformer_engine/common/CMakeLists.txt  View file @ ab3e5a92
...
@@ -116,6 +116,8 @@ if(USE_CUDA)
     transpose/cast_transpose_fusion.cu
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
+    transpose/quantize_transpose_square_blockwise.cu
+    transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
     fused_attn/fused_attn_f16_max512_seqlen.cu
     fused_attn/fused_attn_f16_arbitrary_seqlen.cu
...
@@ -161,6 +163,8 @@ else()
     transpose/cast_transpose_fusion.cu
     transpose/transpose_fusion.cu
     transpose/multi_cast_transpose.cu
+    # transpose/quantize_transpose_square_blockwise.cu
+    # transpose/quantize_transpose_vector_blockwise.cu
     activation/gelu.cu
     activation/relu.cu
     activation/swiglu.cu
...
@@ -271,6 +275,20 @@ if (NVTE_UB_WITH_MPI)
     target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
 endif()
 
+option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
+if (NVTE_ENABLE_NVSHMEM)
+    add_subdirectory(nvshmem_api)
+    target_link_libraries(transformer_engine PUBLIC nvshmemapi)
+    target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
+endif()
+
 if (USE_CUDA)
     # Hack to enable dynamic loading in cuDNN frontend
     target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
...
transformer_engine/common/activation/activation_template.h
...
@@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   constexpr NVTETensor workspace = nullptr;
   constexpr const NVTETensor grad = nullptr;
 
-  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias,
-                                                        workspace, stream);
+  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
+                                                        workspace, nullptr, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*OP)(ComputeType, const Param &)>
...
@@ -46,8 +46,8 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
   constexpr NVTETensor dbias = nullptr;
   constexpr NVTETensor workspace = nullptr;
 
-  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, output, dbias,
-                                                        workspace, stream);
+  quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
+                                                        workspace, nullptr, stream);
 }
 
 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
...
transformer_engine/common/common.h
...
@@ -83,8 +83,8 @@ struct SimpleTensor {
   SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
 
   operator NVTEBasicTensor() const {
-    const NVTEShape shape = {this->shape.data(), this->shape.size()};
     return {dptr, static_cast<NVTEDType>(dtype),
-            shape};
+            nvte_make_shape(this->shape.data(), this->shape.size())};
   }
 
   int numel() const {
...
@@ -104,6 +104,7 @@ struct Tensor {
...
@@ -104,6 +104,7 @@ struct Tensor {
SimpleTensor
scale_inv
;
SimpleTensor
scale_inv
;
SimpleTensor
columnwise_scale_inv
;
SimpleTensor
columnwise_scale_inv
;
public:
NVTEScalingMode
scaling_mode
;
NVTEScalingMode
scaling_mode
;
Tensor
()
Tensor
()
...
...
@@ -165,6 +166,28 @@ struct Tensor {
         return data.shape;
       }
       break;
+    case NVTE_BLOCK_SCALING_1D:
+    case NVTE_BLOCK_SCALING_2D: {
+      if (!has_data() && has_columnwise_data()) {
+        std::vector<size_t> shape;
+        size_t ndim = columnwise_data.shape.size();
+        shape.reserve(ndim);
+        for (size_t i = 0; i + 1 < ndim; ++i) {
+          shape.push_back(columnwise_data.shape[i + 1]);
+        }
+        if (ndim > 0) {
+          shape.push_back(columnwise_data.shape[0]);
+        }
+        return shape;
+      } else {
+        // NOTE: We may have removed the data pointer from
+        // data by setting usage. In that case, we return
+        // the non-null shape. It is our best guess at the most
+        // recent shape.
+        return data.shape;
+      }
+      break;
+    }
     default:
       NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
       return {};
...
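
For block-scaled tensors that hold only columnwise (transposed) data, shape() rotates the leading dimension to the back to recover the logical rowwise shape. The same index arithmetic in Python, as a quick reference (helper name is illustrative only):

```python
def rowwise_shape_from_columnwise(columnwise_shape):
    """Mirror of the C++ loop above: (d0, d1, ..., dn) -> (d1, ..., dn, d0)."""
    ndim = len(columnwise_shape)
    shape = [columnwise_shape[i + 1] for i in range(ndim - 1)]
    if ndim > 0:
        shape.append(columnwise_shape[0])
    return shape

assert rowwise_shape_from_columnwise([128, 32, 64]) == [32, 64, 128]
```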
@@ -205,10 +228,12 @@ struct Tensor {
 struct QuantizationConfig {
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0f;
+  NVTETensor noop_tensor = nullptr;
 
   static constexpr size_t attr_sizes[] = {
       sizeof(bool),  // force_pow_2_scales
-      sizeof(float)  // amax_epsilon
+      sizeof(float),      // amax_epsilon
+      sizeof(NVTETensor)  // noop_tensor
   };
 };
...
@@ -264,6 +289,36 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
 #endif
 #undef TRANSFORMER_ENGINE_TYPE_NAME
 
+template <typename T>
+struct TypeExtrema;
+
+template <>
+struct TypeExtrema<fp8e4m3> {
+  static constexpr float max = 448.0f;
+};
+
+template <>
+struct TypeExtrema<fp8e5m2> {
+  static constexpr float max = 57344.0f;
+};
+
+template <>
+struct TypeExtrema<bf16> {
+  // Hex float format of 1.(7 bits of 1) * 2 ^ 127
+  static constexpr float max = 0x1.FEp127;
+};
+
+template <>
+struct TypeExtrema<fp16> {
+  // Hex float format of 1.(10 bits of 1) * 2 ^ 15
+  static constexpr float max = 0x1.FFCp15;
+};
+
+template <typename T>
+struct TypeExtrema {
+  static constexpr float max = std::numeric_limits<T>::max();
+};
+
 }  // namespace detail
 
 template <typename T>
...
@@ -294,6 +349,7 @@ struct TypeInfo {
   constexpr static DType dtype = getType<T>();
   constexpr static size_t size = sizeof(T);
+  constexpr static float max_finite_value = detail::TypeExtrema<T>::max;
   constexpr static const char *name = detail::type_name<T>();
 };
...
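
The hard-coded extrema can be cross-checked against torch.finfo. A quick sketch, assuming a PyTorch build that exposes the float8 dtypes (values taken from the specializations above):

```python
import torch

# Max finite values hard-coded in detail::TypeExtrema above.
assert torch.finfo(torch.float8_e4m3fn).max == 448.0
assert torch.finfo(torch.float8_e5m2).max == 57344.0
# 0x1.FEp127: mantissa of seven 1-bits, exponent 127 (bf16 max).
assert torch.finfo(torch.bfloat16).max == float.fromhex("0x1.FEp127")
# 0x1.FFCp15: mantissa of ten 1-bits, exponent 15 (fp16 max = 65504).
assert torch.finfo(torch.float16).max == float.fromhex("0x1.FFCp15") == 65504.0
```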
transformer_engine/common/fused_rope/fused_rope.cu
...
@@ -16,10 +16,11 @@ namespace transformer_engine {
 template <typename scalar_t>
 __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                         const int s_id, const int offset_block,
-                                         const int offset_block_dst, const int h, const int d,
-                                         const int d2, const int stride_h, const int stride_d,
-                                         const int o_stride_h, const int o_stride_d) {
+                                         const bool interleaved, const int s_id,
+                                         const int offset_block, const int offset_block_dst,
+                                         const int h, const int d, const int d2,
+                                         const int stride_h, const int stride_d,
+                                         const int o_stride_h, const int o_stride_d) {
 #pragma unroll
   for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
     float v_cos, v_sin;
...
@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
       int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
       int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
       float v_src = src[offset_src];
-      float v_src_rotate = (d_id + d2 / 2 < d2)
-                               ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
-                               : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
+      float v_src_rotate;
+      if (!interleaved) {
+        v_src_rotate = (d_id + d2 / 2 < d2)
+                           ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
+                           : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
+      } else {
+        v_src_rotate = (d_id % 2 == 0)
+                           // d_id + 1
+                           ? -static_cast<float>(src[offset_src + stride_d])
+                           // d_id - 1
+                           : static_cast<float>(src[offset_src - stride_d]);
+      }
       dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
     }
   }
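
The new interleaved branch pairs each even element with its odd neighbor instead of pairing the two halves of the rotary dimension. The two rotations as a PyTorch reference (an independent sketch, not this kernel):

```python
import torch

def rotate_half(x):
    # non-interleaved: pair x[..., i] with x[..., i + d2 // 2]
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_interleaved(x):
    # interleaved: pair x[..., 2i] with x[..., 2i + 1]
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

# dst = src * cos(freqs) + rotate(src) * sin(freqs), as computed per element above.
```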
...
@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
 template <typename scalar_t>
 __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                          const int s_id, const int offset_block,
-                                          const int offset_block_dst, const int h, const int d,
-                                          const int d2, const int stride_h, const int stride_d,
-                                          const int o_stride_h, const int o_stride_d) {
+                                          const bool interleaved, const int s_id,
+                                          const int offset_block, const int offset_block_dst,
+                                          const int h, const int d, const int d2,
+                                          const int stride_h, const int stride_d,
+                                          const int o_stride_h, const int o_stride_d) {
 #pragma unroll
   for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
     float v_cos = cosf(freqs[s_id * d2 + d_id]);
-    float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
-                                       : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
+    float v_sin;
+    if (!interleaved) {
+      v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
+                                   : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
+    } else {
+      v_sin = (d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1])
+                              : -sinf(freqs[s_id * d2 + d_id - 1]);
+    }
 #pragma unroll
     for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
       int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
       int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
       float v_src = src[offset_src];
-      float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d]
-                                                : src[offset_src + (d2 / 2 - d2) * stride_d];
+      float v_src_rotate;
+      if (!interleaved) {
+        v_src_rotate = (d_id + d2 / 2 < d2)
+                           ? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
+                           : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
+      } else {
+        v_src_rotate = (d_id % 2 == 0)
+                           // d_id + 1
+                           ? static_cast<float>(src[offset_src + stride_d])
+                           // d_id - 1
+                           : static_cast<float>(src[offset_src - stride_d]);
+      }
       dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
     }
   }
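
The backward block mirrors the forward with the sign moved from the rotated source onto sin: the rotation's adjoint. A quick numerical check of that identity for the interleaved path, written as an independent double-precision reference (not this kernel):

```python
import torch

def apply_rope_interleaved(x, freqs):
    # dst = x * cos + rotate(x) * sin, with the interleaved pairing above
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    rot = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rot * sin

x = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
freqs = torch.randn(4, 8, dtype=torch.double)
# gradcheck compares analytic and numerical gradients; passing confirms that the
# sign-flipped-sin backward is indeed the adjoint of the forward rotation.
assert torch.autograd.gradcheck(lambda t: apply_rope_interleaved(t, freqs), (x,))
```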
...
@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
 }
 
 template <typename scalar_t>
-__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                          const int h, const int d, const int d2,
-                                          const int stride_s, const int stride_b,
-                                          const int stride_h, const int stride_d,
-                                          const int o_stride_s, const int o_stride_b,
-                                          const int o_stride_h, const int o_stride_d) {
-  int s_id = blockIdx.x, b_id = blockIdx.y;
-  int offset_block = s_id * stride_s + b_id * stride_b;
-  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
-  fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
-                           stride_h, stride_d, o_stride_h, o_stride_d);
-}
-
-template <typename scalar_t>
-__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
-                                           const int h, const int d, const int d2,
-                                           const int stride_s, const int stride_b,
-                                           const int stride_h, const int stride_d,
-                                           const int o_stride_s, const int o_stride_b,
-                                           const int o_stride_h, const int o_stride_d) {
-  int s_id = blockIdx.x, b_id = blockIdx.y;
-  int offset_block = s_id * stride_s + b_id * stride_b;
-  int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
-  fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
-                            stride_h, stride_d, o_stride_h, o_stride_d);
-}
-
-template <typename scalar_t>
-__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                              const float *freqs, scalar_t *dst, const int cp_size,
-                                              const int cp_rank, const int h, const int d,
-                                              const int d2, const int stride_t, const int stride_h,
-                                              const int stride_d, const int o_stride_t,
-                                              const int o_stride_h, const int o_stride_d) {
+__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
+                                          const float *freqs, scalar_t *dst,
+                                          const bool interleaved, const int cp_size,
+                                          const int cp_rank, const int s, const int h,
+                                          const int d, const int d2, const int stride_s_or_t,
+                                          const int stride_b, const int stride_h,
+                                          const int stride_d, const int o_stride_s_or_t,
+                                          const int o_stride_b, const int o_stride_h,
+                                          const int o_stride_d) {
   int s_id = blockIdx.x, b_id = blockIdx.y;
-  int start = cu_seqlens[b_id] / cp_size;
-  int end = cu_seqlens[b_id + 1] / cp_size;
-  int t_id = s_id + start;
-  if (t_id >= end) return;
-  int offset_block = t_id * stride_t;
-  int offset_block_dst = t_id * o_stride_t;
+  int offset_block, offset_block_dst;
+  int cur_seqlens;
+  if (cu_seqlens != nullptr) {  // THD
+    int start = cu_seqlens[b_id] / cp_size;
+    int end = cu_seqlens[b_id + 1] / cp_size;
+    int t_id = s_id + start;
+    if (t_id >= end) return;
+    offset_block = t_id * stride_s_or_t;
+    offset_block_dst = t_id * o_stride_s_or_t;
+    cur_seqlens = end - start;
+  } else {  // SBHD/BSHD
+    offset_block = s_id * stride_s_or_t + b_id * stride_b;
+    offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
+    cur_seqlens = s;
+  }
   int s_id_for_freqs;
   if (cp_size > 1) {
-    int cur_seqlens = end - start;
     assert(cur_seqlens % 2 == 0);
     if (s_id < cur_seqlens / 2) {
       s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...
@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu
   } else {
     s_id_for_freqs = s_id;
   }
-  fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
-                           d2, stride_h, stride_d, o_stride_h, o_stride_d);
+  fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
+                           offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,
+                           o_stride_d);
 }
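
With cu_seqlens folded in, one kernel now serves both the ragged THD layout and the dense SBHD/BSHD layouts. A Python sketch of the block-offset selection above (cu_seqlens holds per-batch cumulative token counts; cp_size divides each sequence for context parallelism; the helper name is illustrative):

```python
def block_offsets(s_id, b_id, *, cu_seqlens, cp_size, stride_s_or_t, stride_b, s):
    """Mirror of the kernel branch: returns (offset_block, cur_seqlens) or None."""
    if cu_seqlens is not None:  # THD: ragged, one shared token axis
        start = cu_seqlens[b_id] // cp_size
        end = cu_seqlens[b_id + 1] // cp_size
        t_id = s_id + start
        if t_id >= end:
            return None  # no token at this position for a shorter sequence
        return t_id * stride_s_or_t, end - start
    # SBHD/BSHD: dense, separate sequence and batch strides
    return s_id * stride_s_or_t + b_id * stride_b, s
```

The second branch of the context-parallel frequency-index mapping is elided in this hunk ("..."); only the first-half mapping, s_id + cp_rank * cur_seqlens / 2, is visible above.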
 
 template <typename scalar_t>
-__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                               const float *freqs, scalar_t *dst,
-                                               const int cp_size, const int cp_rank, const int h,
-                                               const int d, const int d2, const int stride_t,
-                                               const int stride_h, const int stride_d,
-                                               const int o_stride_t, const int o_stride_h,
-                                               const int o_stride_d) {
+__global__ void fused_rope_backward_kernel(const scalar_t *src, const int *cu_seqlens,
+                                           const float *freqs, scalar_t *dst,
+                                           const bool interleaved, const int cp_size,
+                                           const int cp_rank, const int s, const int h,
+                                           const int d, const int d2, const int stride_s_or_t,
+                                           const int stride_b, const int stride_h,
+                                           const int stride_d, const int o_stride_s_or_t,
+                                           const int o_stride_b, const int o_stride_h,
+                                           const int o_stride_d) {
   int s_id = blockIdx.x, b_id = blockIdx.y;
-  int start = cu_seqlens[b_id] / cp_size;
-  int end = cu_seqlens[b_id + 1] / cp_size;
-  int t_id = s_id + start;
-  if (t_id >= end) return;
-  int offset_block = t_id * stride_t;
-  int offset_block_dst = t_id * o_stride_t;
+  int offset_block, offset_block_dst;
+  int cur_seqlens;
+  if (cu_seqlens != nullptr) {  // THD
+    int start = cu_seqlens[b_id] / cp_size;
+    int end = cu_seqlens[b_id + 1] / cp_size;
+    int t_id = s_id + start;
+    if (t_id >= end) return;
+    offset_block = t_id * stride_s_or_t;
+    offset_block_dst = t_id * o_stride_s_or_t;
+    cur_seqlens = end - start;
+  } else {  // SBHD/BSHD
+    offset_block = s_id * stride_s_or_t + b_id * stride_b;
+    offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
+    cur_seqlens = s;
+  }
   int s_id_for_freqs;
   if (cp_size > 1) {
-    int cur_seqlens = end - start;
     assert(cur_seqlens % 2 == 0);
     if (s_id < cur_seqlens / 2) {
       s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...
@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c
   } else {
     s_id_for_freqs = s_id;
   }
-  fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
-                            d2, stride_h, stride_d, o_stride_h, o_stride_d);
+  fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
+                            offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h,
+                            o_stride_d);
 }
 
 template <typename scalar_t>
-void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output,
-                                 const int s, const int b, const int h, const int d, const int d2,
-                                 const int stride_s, const int stride_b, const int stride_h,
-                                 const int stride_d, const int o_stride_s, const int o_stride_b,
-                                 const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
+void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
+                                 scalar_t *output, const NVTE_QKV_Format qkv_format,
+                                 const bool interleaved, const int cp_size, const int cp_rank,
+                                 const int s, const int b, const int h, const int d, const int d2,
+                                 const int stride_s_or_t, const int stride_b, const int stride_h,
+                                 const int stride_d, cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
+  int o_stride_s_or_t, o_stride_b;
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
+    o_stride_s_or_t = h * d;
+    o_stride_b = 0;
+  } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
+    o_stride_s_or_t = b * h * d;
+    o_stride_b = h * d;
+  } else {
+    o_stride_s_or_t = h * d;
+    o_stride_b = s * h * d;
+  }
+  const int o_stride_h = d;
+  const int o_stride_d = 1;
 
   fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
-      input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
-      o_stride_b, o_stride_h, o_stride_d);
+      input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
+      stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
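
The launcher now derives the output strides from the layout so the kernel always writes a densely packed tensor. The same arithmetic in Python (layout names per NVTE_QKV_Format; helper name is illustrative):

```python
def output_strides(qkv_format, s, b, h, d):
    """(o_stride_s_or_t, o_stride_b) as chosen in the launcher above."""
    if qkv_format == "thd":
        return h * d, 0          # single token axis; batch folded into it
    if qkv_format == "sbhd":
        return b * h * d, h * d  # sequence-major layout
    return h * d, s * h * d      # bshd: batch-major layout

assert output_strides("sbhd", s=3, b=2, h=4, d=8) == (2 * 4 * 8, 4 * 8)
```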
 
 template <typename scalar_t>
-void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs,
-                                  scalar_t *input_grads, const int s, const int b, const int h,
-                                  const int d, const int d2, const int stride_s, const int stride_b,
-                                  const int stride_h, const int stride_d, const int o_stride_s,
-                                  const int o_stride_b, const int o_stride_h, const int o_stride_d,
-                                  cudaStream_t stream) {
+void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
+                                  const float *freqs, scalar_t *input_grads,
+                                  const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                  const int cp_size, const int cp_rank, const int s, const int b,
+                                  const int h, const int d, const int d2, const int stride_s_or_t,
+                                  const int stride_b, const int stride_h, const int stride_d,
+                                  cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
+  int o_stride_s_or_t, o_stride_b;
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
+    NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
+    o_stride_s_or_t = h * d;
+    o_stride_b = 0;
+  } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
+    o_stride_s_or_t = b * h * d;
+    o_stride_b = h * d;
+  } else {
+    o_stride_s_or_t = h * d;
+    o_stride_b = s * h * d;
+  }
+  const int o_stride_h = d;
+  const int o_stride_d = 1;
 
   fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
-      output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d,
-      o_stride_s, o_stride_b, o_stride_h, o_stride_d);
+      output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
+      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
+      o_stride_d);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
-template <typename scalar_t>
-void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
-                                     const float *freqs, scalar_t *output, const int cp_size,
-                                     const int cp_rank, const int max_s, const int b, const int h,
-                                     const int d, const int d2, const int stride_t,
-                                     const int stride_h, const int stride_d, const int o_stride_t,
-                                     const int o_stride_h, const int o_stride_d,
-                                     cudaStream_t stream) {
-  int warps_per_block = h < 16 ? 4 : 8;
-  dim3 blocks(max_s, b);
-  dim3 threads(THREADS_PER_WARP, warps_per_block);
-
-  fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
-      input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
-      o_stride_t, o_stride_h, o_stride_d);
-  NVTE_CHECK_CUDA(cudaGetLastError());
-}
 
 template <typename scalar_t>
-void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
-                                      const float *freqs, scalar_t *input_grads, const int cp_size,
-                                      const int cp_rank, const int max_s, const int b, const int h,
-                                      const int d, const int d2, const int stride_t,
-                                      const int stride_h, const int stride_d, const int o_stride_t,
-                                      const int o_stride_h, const int o_stride_d,
-                                      cudaStream_t stream) {
-  int warps_per_block = h < 16 ? 4 : 8;
-  dim3 blocks(max_s, b);
-  dim3 threads(THREADS_PER_WARP, warps_per_block);
-
-  fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
-      output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t,
-      stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d);
-  NVTE_CHECK_CUDA(cudaGetLastError());
-}
-
-void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
-                        const int b, const int h, const int d, const int d2, const int stride_s,
-                        const int stride_b, const int stride_h, const int stride_d,
-                        const int o_stride_s, const int o_stride_b, const int o_stride_h,
-                        const int o_stride_d, cudaStream_t stream) {
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, scalar_t,
-      fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
-                                  reinterpret_cast<const float *>(freqs.data.dptr),
-                                  reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
-                                  stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                                  o_stride_h, o_stride_d, stream););
-}
-
-void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
-                         const int s, const int b, const int h, const int d, const int d2,
-                         const int stride_s, const int stride_b, const int stride_h,
-                         const int stride_d, const int o_stride_s, const int o_stride_b,
-                         const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      output_grads.data.dtype, scalar_t,
-      fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
-                                   reinterpret_cast<const float *>(freqs.data.dptr),
-                                   reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h,
-                                   d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
-                                   o_stride_b, o_stride_h, o_stride_d, stream););
-}
-
-void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
-                            Tensor *output, const int cp_size, const int cp_rank, const int max_s,
-                            const int b, const int h, const int d, const int d2,
-                            const int stride_t, const int stride_h, const int stride_d,
-                            const int o_stride_t, const int o_stride_h, const int o_stride_d,
-                            cudaStream_t stream) {
-  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, scalar_t,
-      fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
-                                      reinterpret_cast<const int *>(cu_seqlens.data.dptr),
-                                      reinterpret_cast<const float *>(freqs.data.dptr),
-                                      reinterpret_cast<scalar_t *>(output->data.dptr), cp_size,
-                                      cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
-                                      o_stride_t, o_stride_h, o_stride_d, stream););
-}
+
+void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
+                        Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
+                        const int cp_size, const int cp_rank, const int s, const int b,
+                        const int h, const int d, const int d2, const int stride_s_or_t,
+                        const int stride_b, const int stride_h, const int stride_d,
+                        cudaStream_t stream) {
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+      input.data.dtype, scalar_t,
+      fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
+                                  reinterpret_cast<const int *>(cu_seqlens.data.dptr),
+                                  reinterpret_cast<const float *>(freqs.data.dptr),
+                                  reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
+                                  interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+                                  stride_b, stride_h, stride_d, stream););
+}
 
-void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
-                             const Tensor &freqs, Tensor *input_grads, const int cp_size,
-                             const int cp_rank, const int max_s, const int b, const int h,
-                             const int d, const int d2, const int stride_t, const int stride_h,
-                             const int stride_d, const int o_stride_t, const int o_stride_h,
-                             const int o_stride_d, cudaStream_t stream) {
+void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
+                         const Tensor &freqs, Tensor *input_grads,
+                         const NVTE_QKV_Format qkv_format, const bool interleaved,
+                         const int cp_size, const int cp_rank, const int s, const int b,
+                         const int h, const int d, const int d2, const int stride_s_or_t,
+                         const int stride_b, const int stride_h, const int stride_d,
+                         cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       output_grads.data.dtype, scalar_t,
-      fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
-                                       reinterpret_cast<const int *>(cu_seqlens.data.dptr),
-                                       reinterpret_cast<const float *>(freqs.data.dptr),
-                                       reinterpret_cast<scalar_t *>(input_grads->data.dptr),
-                                       cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h,
-                                       stride_d, o_stride_t, o_stride_h, o_stride_d, stream););
+      fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
+                                   reinterpret_cast<const int *>(cu_seqlens.data.dptr),
+                                   reinterpret_cast<const float *>(freqs.data.dptr),
+                                   reinterpret_cast<scalar_t *>(input_grads->data.dptr),
+                                   qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
+                                   stride_s_or_t, stride_b, stride_h, stride_d, stream););
 }
 
 }  // end namespace transformer_engine
 
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
-                             const int s, const int b, const int h, const int d, const int d2,
-                             const int stride_s, const int stride_b, const int stride_h,
-                             const int stride_d, const int o_stride_s, const int o_stride_b,
-                             const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
+                             const NVTETensor freqs, NVTETensor output,
+                             const NVTE_QKV_Format qkv_format, const bool interleaved,
+                             const int cp_size, const int cp_rank, const int s, const int b,
+                             const int h, const int d, const int d2, const int stride_s_or_t,
+                             const int stride_b, const int stride_h, const int stride_d,
+                             cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_forward);
   using namespace transformer_engine;
   fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
-                     *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
-                     s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
-                     o_stride_b, o_stride_h, o_stride_d, stream);
+                     *reinterpret_cast<const Tensor *>(cu_seqlens),
+                     *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
+                     qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
+                     stride_b, stride_h, stride_d, stream);
 }
 
-void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
-                              NVTETensor input_grads, const int s, const int b, const int h,
-                              const int d, const int d2, const int stride_s, const int stride_b,
-                              const int stride_h, const int stride_d, const int o_stride_s,
-                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
-                              cudaStream_t stream) {
+void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
+                              const NVTETensor freqs, NVTETensor input_grads,
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const int cp_size, const int cp_rank, const int s, const int b,
+                              const int h, const int d, const int d2, const int stride_s_or_t,
+                              const int stride_b, const int stride_h, const int stride_d,
+                              cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_backward);
   using namespace transformer_engine;
   fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
-                      *reinterpret_cast<const Tensor *>(freqs),
-                      reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2, stride_s, stride_b,
-                      stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
-}
-
-void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                                 const NVTETensor freqs, NVTETensor output, const int cp_size,
-                                 const int cp_rank, const int max_s, const int b, const int h,
-                                 const int d, const int d2, const int stride_t,
-                                 const int stride_h, const int stride_d, const int o_stride_t,
-                                 const int o_stride_h, const int o_stride_d,
-                                 cudaStream_t stream) {
-  NVTE_API_CALL(nvte_fused_rope_thd_forward);
-  using namespace transformer_engine;
-  fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
-                         *reinterpret_cast<const Tensor *>(cu_seqlens),
-                         *reinterpret_cast<const Tensor *>(freqs),
-                         reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2,
-                         stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                         stream);
-}
-
-void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
-                                  const NVTETensor freqs, NVTETensor input_grads,
-                                  const int cp_size, const int cp_rank, const int max_s,
-                                  const int b, const int h, const int d, const int d2,
-                                  const int stride_t, const int stride_h, const int stride_d,
-                                  const int o_stride_t, const int o_stride_h,
-                                  const int o_stride_d, cudaStream_t stream) {
-  NVTE_API_CALL(nvte_fused_rope_thd_backward);
-  using namespace transformer_engine;
-  fused_rope_thd_backward(*reinterpret_cast<const Tensor *>(output_grads),
-                          *reinterpret_cast<const Tensor *>(cu_seqlens),
-                          *reinterpret_cast<const Tensor *>(freqs),
-                          reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h,
-                          d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
-                          o_stride_d, stream);
+                      *reinterpret_cast<const Tensor *>(cu_seqlens),
+                      *reinterpret_cast<const Tensor *>(freqs),
+                      reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
+                      cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d,
+                      stream);
 }