OpenDAS / TransformerEngine

Commit a207db1d, authored Apr 01, 2025 by yuguo

    Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: fbee8990, 69365f88
Changes: 101 files
Showing 20 changed files with 154 additions and 64 deletions (+154 -64)
tests/jax/utils.py  +92 -17
tests/pytorch/distributed/run_numerics.py  +0 -1
tests/pytorch/distributed/test_fusible_ops.py  +0 -1
tests/pytorch/fused_attn/run_fused_attn_with_cp.py  +6 -0
tests/pytorch/fused_attn/test_kv_cache.py  +5 -5
tests/pytorch/test_cuda_graphs.py  +1 -0
tests/pytorch/test_fusible_ops.py  +0 -1
tests/pytorch/test_jit.py  +7 -0
tests/pytorch/test_numerics.py  +1 -1
tests/pytorch/test_recipe.py  +0 -1
tests/pytorch/test_sanity.py  +8 -21
transformer_engine/__init__.py  +0 -5
transformer_engine/common/CMakeLists.txt  +2 -1
transformer_engine/common/gemm/cublaslt_gemm.cu  +10 -0
transformer_engine/common/include/transformer_engine/gemm.h  +7 -0
transformer_engine/common/include/transformer_engine/normalization.h  +2 -0
transformer_engine/common/include/transformer_engine/transformer_engine.h  +9 -1
transformer_engine/common/libtransformer_engine.version  +3 -1
transformer_engine/common/normalization/common.h  +1 -1
transformer_engine/common/recipe/__init__.py  +0 -7
tests/jax/utils.py
@@ -18,13 +18,14 @@ from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn
 from jax import random as jax_random
+import pytest
 from transformer_engine.jax.attention import (
     AttnMaskType,
     canonicalize_attn_mask_type,
     make_swa_mask,
 )
-from transformer_engine.jax.fp8 import DType as TEDType
+from transformer_engine.jax.quantize.helper import DType as TEDType

 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -96,6 +97,62 @@ def combine_biases(*masks: Optional[Array]):
     return mask


+def parameterize_by_test_level(param_dict: dict, id_prefix: str = ""):
+    """
+    Takes an input dictionary of parameters keyed by test type "L0", etc.
+    Returns a list of pytest parameters to be used in a parameterized test for the current test type
+    """
+    DEFAULT_TEST_LEVEL = "L0"
+    test_level = os.environ.get("NVTE_JAX_UNITTEST_LEVEL", DEFAULT_TEST_LEVEL)
+    if test_level not in param_dict:
+        raise ValueError("Unsupported test level")
+    return values_to_named_params(param_dict[test_level], id_prefix)
+
+
+def value_to_test_name_str(value):
+    """Converts a value to how it should appear in a test name."""
+    if isinstance(value, tuple) or isinstance(value, list):
+        return "_".join([value_to_test_name_str(v) for v in value])
+    dtype_type = type(jnp.float32)
+    if isinstance(value, dtype_type):
+        return value.dtype
+    return str(value)
+
+
+def value_to_named_param(value, id_prefix: str = ""):
+    param_type = type(pytest.param(0))
+    if isinstance(value, param_type):
+        return value
+    x = pytest.param(value, id=f"{id_prefix}_{value_to_test_name_str(value)}")
+    return x
+
+
+def values_to_named_params(params, id_prefix: str = ""):
+    return [value_to_named_param(v, id_prefix=id_prefix) for v in params]
+
+
+def pytest_parametrize_wrapper(param_name, param_values):
+    """
+    A wrapper for pytest.mark.parametrize to allow for automatic
+    naming of tests based on the parameter values.
+    """
+    id_prefix = param_name
+    if isinstance(param_values, dict):
+        param_values = parameterize_by_test_level(param_values, id_prefix=param_name)
+    elif "," not in param_name:
+        param_values = values_to_named_params(param_values, id_prefix=id_prefix)
+    # Currently comma separated parameters in one parametrize call aren't supported for automatic naming
+    # and will just be passed through with default pytest names
+
+    def decorator(func):
+        return pytest.mark.parametrize(param_name, param_values)(func)
+
+    return decorator
+
+
 class DotProductAttention(nn.Module):
     transpose_batch_sequence: bool = True
     scale_attn_logits: bool = True
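Aside (not part of the diff): a hypothetical sketch of how these new helpers would be used in a test module under tests/jax/. The parameter dict and test below are made up for illustration; the test level is selected with the NVTE_JAX_UNITTEST_LEVEL environment variable (default "L0"), and each parameter gets a readable pytest id such as "batch_size_2".

    # Illustrative only; BATCH_SIZES and test_batch_size_is_positive are hypothetical.
    from utils import pytest_parametrize_wrapper  # i.e. tests/jax/utils.py

    # Parameters keyed by test level: "L0" runs by default,
    # NVTE_JAX_UNITTEST_LEVEL=L1 selects the extended set.
    BATCH_SIZES = {
        "L0": [2],
        "L1": [2, 8, 32],
    }

    @pytest_parametrize_wrapper("batch_size", BATCH_SIZES)
    def test_batch_size_is_positive(batch_size):
        assert batch_size > 0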
@@ -140,6 +197,7 @@ class DotProductAttention(nn.Module):
         Returns:
           Output of shape `[batch, length, num_heads, v_depth_per_head]`.
         """
+        input_dtype = query.dtype
         assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
         batch_dim = 1 if self.transpose_batch_sequence else 0
         assert (
@@ -152,7 +210,7 @@ class DotProductAttention(nn.Module):
         if self.scale_attn_logits:
             head_dim = query.shape[-1]
-            depth_scaling = jnp.sqrt(head_dim).astype(self.dtype)
+            depth_scaling = jnp.sqrt(head_dim).astype(input_dtype)
             query = query / depth_scaling

         # Casting logits and softmax computation for float32 for model stability.
@@ -181,7 +239,7 @@ class DotProductAttention(nn.Module):
             attn_weights = attn_weights + bias.astype(attn_weights.dtype)

         # Normalize the attention weights across `kv_length` dimension.
-        attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
+        attn_weights = jax_nn.softmax(attn_weights).astype(input_dtype)

         # Apply attention dropout.
         if not deterministic and self.dropout_rate > 0.0:
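Aside (not part of the diff): these hunks follow a common JAX dtype-preservation pattern — do the numerically sensitive work (softmax, normalization statistics) in float32, then cast back to the dtype of the incoming activations rather than to a module-level dtype. A minimal sketch, assuming bfloat16 inputs:

    import jax.numpy as jnp
    from jax import nn as jax_nn

    logits = jnp.ones((2, 8), dtype=jnp.bfloat16)
    # Compute softmax in float32 for stability, then restore the input dtype.
    probs = jax_nn.softmax(logits.astype(jnp.float32)).astype(logits.dtype)
    assert probs.dtype == logits.dtype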
@@ -191,16 +249,20 @@ class DotProductAttention(nn.Module):
             dropout_shape = list(attn_weights.shape)
             dropout_rng = self.make_rng("dropout")
             keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
-            multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
+            multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype)
             attn_weights = attn_weights * multiplier

         attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
-        attn_weights = attn_weights.astype(value.dtype)
+        # attn_weights = attn_weights.astype(input_dtype)

         # Take the linear combination of `value`.
         if self.transpose_batch_sequence:
             return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)

+        assert (
+            attn_weights.dtype == input_dtype
+        ), f"input.dtype={input_dtype}, output.dtype={attn_weights.dtype}"
         return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
@@ -246,7 +308,6 @@ class DenseGeneral(nn.Module):
         features = _canonicalize_tuple(self.features)
         axis = _canonicalize_tuple(self.axis)

-        inputs = jnp.asarray(inputs, self.dtype)
         axis = _normalize_axes(axis, inputs.ndim)

         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
@@ -268,11 +329,14 @@ class DenseGeneral(nn.Module):
         contract_ind = tuple(range(0, len(axis)))
-        y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
-        y = y.astype(input_dtype)
+        y = lax.dot_general(
+            inputs, kernel, ((axis, contract_ind), ((), ())), preferred_element_type=input_dtype
+        )

         if bias is not None:
             y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

+        assert y.dtype == inputs.dtype, f"input.dtype={inputs.dtype}, output.dtype={y.dtype}"
         return y
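Aside (not part of the diff): `preferred_element_type` asks `lax.dot_general` to produce its result directly in the requested dtype instead of casting afterwards. A minimal sketch of the behavior relied on above, with made-up shapes:

    import jax.numpy as jnp
    from jax import lax

    x = jnp.ones((4, 8), dtype=jnp.bfloat16)
    w = jnp.ones((8, 16), dtype=jnp.bfloat16)
    # Contract the last axis of x with the first axis of w, keeping a bfloat16 output.
    y = lax.dot_general(x, w, (((1,), (0,)), ((), ())), preferred_element_type=jnp.bfloat16)
    assert y.dtype == jnp.bfloat16 and y.shape == (4, 16)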
@@ -352,6 +416,7 @@ class MlpBlock(nn.Module):
         )(x, deterministic=deterministic)

         # Broadcast along length.
         if self.transpose_batch_sequence:
             x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
         else:
@@ -365,6 +430,7 @@ class MlpBlock(nn.Module):
             bias_axes="embed",
             name="wo",
         )(x)
+        assert (
+            output.dtype == inputs.dtype
+        ), f"input.dtype={inputs.dtype}, output.dtype={output.dtype}"
@@ -391,7 +457,7 @@ def apply_rotary_pos_emb_alternate(
     second_part = second_half * cos + first_half * sin
     first_part = first_part.astype(inputs.dtype)
     second_part = second_part.astype(inputs.dtype)
-    return jnp.concatenate([first_part, second_part], axis=-1)
+    return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)


 def apply_rotary_pos_emb_consecutive(
@@ -425,7 +491,7 @@ def apply_rotary_pos_emb_consecutive(
     sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
     outputs = inputs * cos + inputs_shifted * sin * sign
-    return outputs
+    return outputs.astype(inputs.dtype)


 dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
@@ -559,6 +625,7 @@ class MultiHeadAttention(nn.Module):
         if self.fuse_qkv:
             if is_qkvpack:
                 qkv_proj = DenseGeneral(
                     axis=-1,
                     features=self.num_heads * self.head_dim * 3,
@@ -569,11 +636,13 @@ class MultiHeadAttention(nn.Module):
                     name="qkv",
                     dtype=self.dtype,
                 )(inputs_kv)
                 query, key, value = jnp.split(
                     qkv_proj,
                     [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
                     axis=-1,
                 )
         else:
             query = q_projection(kernel_init=query_init, name="query")(inputs_q)
@@ -711,6 +780,7 @@ class MultiHeadAttention(nn.Module):
         # Convert the boolean attention mask to an attention bias.
         if mask is not None:
             # attention mask in the form of attention bias
             attention_bias = lax.select(
                 mask > 0,
                 jnp.full(mask.shape, 0.0).astype(self.dtype),
@@ -740,6 +810,7 @@ class MultiHeadAttention(nn.Module):
             x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.
         out = DenseGeneral(
             features=inputs_q.shape[-1],  # output dim is set to the input dim.
             axis=-1,
@@ -750,6 +821,7 @@ class MultiHeadAttention(nn.Module):
             dtype=self.dtype,
             name="out",
         )(x)
+        assert (
+            inputs_q.dtype == inputs_kv.dtype == out.dtype
+        ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}"
@@ -784,12 +856,11 @@ class LayerNorm(nn.Module):
         scale = nn_partitioning.param_with_axes(
             "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
         )
         scale = jnp.asarray(scale, input_dtype)

+        x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":
-            mean = jnp.mean(x, axis=-1, keepdims=True)
-            var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
-            y = (x - mean) * lax.rsqrt(var + self.epsilon)
+            mean = jnp.mean(x_, axis=-1, keepdims=True)
+            var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
+            y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
             bias = nn_partitioning.param_with_axes(
                 "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
@@ -803,9 +874,10 @@ class LayerNorm(nn.Module):
         else:
             assert self.layernorm_type == "rmsnorm"
             assert not self.zero_centered_gamma
-            mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
-            y = x * lax.rsqrt(mean2 + self.epsilon)
+            mean2 = jnp.mean(lax.square(x_), axis=-1, keepdims=True)
+            y = x_ * lax.rsqrt(mean2 + self.epsilon)

         z = y * scale
+        z = z.astype(input_dtype)
+        assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}"
         return z
@@ -1085,9 +1157,11 @@ class EncoderLayer(nn.Module):
             fuse_wi=self.fuse_mlp_wi,
             name="mlp",
         )(y, deterministic=deterministic)
         y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
             y, deterministic=deterministic
         )
         if self.drop_path > 0.0:
             drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
             y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
@@ -1103,6 +1177,7 @@ class EncoderLayer(nn.Module):
             dtype=self.dtype,
             name="output_layernorm",
         )(y)
+        assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}"
         return y
tests/pytorch/distributed/run_numerics.py

@@ -318,7 +318,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
             device=device,
             with_amax_reduction=True,
             amax_reduction_group=tp_group,
-            amax_reduction_size=tp_size,
         )
     quantizer = quantizer_class(
         fp8_dtype=fp8_dtype,
tests/pytorch/distributed/test_fusible_ops.py

@@ -741,7 +741,6 @@ def _test_fp8_scale_update(
     fp8_format = transformer_engine.common.recipe.Format.HYBRID
     recipe = transformer_engine.common.recipe.DelayedScaling(
         margin=margin,
-        interval=1,
         fp8_format=fp8_format,
         amax_history_len=amax_history_len,
         amax_compute_algo=amax_compute_algo,
tests/pytorch/fused_attn/run_fused_attn_with_cp.py

@@ -286,6 +286,12 @@ def run_dpa_with_cp(
     else:
         out_.backward(dout_)

+    if fp8_mha:
+        assert isinstance(out, Float8Tensor)
+        assert isinstance(out_, Float8Tensor)
+        out = out.dequantize()
+        out_ = out_.dequantize()
+
     for x in [out_, q_.grad, k_.grad, v_.grad]:
         assert torch.all(~torch.isnan(x))
         assert torch.all(~torch.isinf(x))
tests/pytorch/fused_attn/test_paged_attn.py → tests/pytorch/fused_attn/test_kv_cache.py
@@ -229,7 +229,7 @@ def get_model(
     attn_mask_type = "causal"
     qkv_format = "bshd"
     if mode == "inference":
-        attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding"
+        attn_mask_type = "padding_causal"
         fp8_recipe = recipe.DelayedScaling(
             margin=0,
@@ -392,9 +392,9 @@ def get_tols(module, backend, dtype):
 @pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"])
 @pytest.mark.parametrize("is_cuda_graph", [False, True])
 @pytest.mark.parametrize("is_fp8", [False, True])
-def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
+def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8):
     reset_rng_states()
-    logger = logging.getLogger("test_paged_attn")
+    logger = logging.getLogger("test_kv_cache")
     fp8_recipe = recipe.DelayedScaling(
         margin=0,
         fp8_format=recipe.Format.HYBRID,
@@ -407,7 +407,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
     fp8_meta["recipe"] = fp8_recipe
     config = model_configs_infer[model]

-    num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1
+    num_layers = 2 if module == "TransformerLayer" else 1
     # flash-attn v2 requires page_size >= 256
     if backend == "FlashAttention" and not fa_utils.v3_is_installed:
         config_max_seqlen_q = config.max_seqlen_q
@@ -437,7 +437,7 @@ def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda
     # initialize inference_params
     inference_params = InferenceParams(
         max_batch_size=max_batch_size,
-        max_seqlen_kv=config.max_seqlen_kv,
+        max_sequence_length=config.max_seqlen_kv,
         num_heads_kv=config.num_gqa_groups,
         head_dim_k=config.head_dim_qk,
         head_dim_v=config.head_dim_v,
tests/pytorch/test_cuda_graphs.py

@@ -57,6 +57,7 @@ model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 fp8_recipes = [
     recipe.DelayedScaling(),
     recipe.MXFP8BlockScaling(),
+    recipe.Float8CurrentScaling(),
 ]

 # Supported data types
tests/pytorch/test_fusible_ops.py

@@ -297,7 +297,6 @@ class TestFuser:
     fp8_format = transformer_engine.common.recipe.Format.HYBRID
     recipe = transformer_engine.common.recipe.DelayedScaling(
         margin=margin,
-        interval=1,
         fp8_format=fp8_format,
         amax_history_len=8,
         amax_compute_algo="max",
tests/pytorch/test_jit.py

@@ -56,3 +56,10 @@ def test_torch_dynamo(model_name: str):
     # Forward and backward pass
     out = model(*inputs)
     out.backward(torch.zeros_like(out))
+
+
+def test_lazy_compile():
+    """Smoke test to ensure lazy compilation is working."""
+    from transformer_engine.pytorch.jit import dgelu_fused_
+
+    dgelu_fused_(torch.randn(10, 10), torch.randn(10, 10))
tests/pytorch/test_numerics.py

@@ -2144,7 +2144,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
     inference_params = InferenceParams(
         max_batch_size=B_max,
-        max_seqlen_kv=S_max,
+        max_sequence_length=S_max,
         num_heads_kv=H,
         head_dim_k=head_size,
         dtype=dtype,
tests/pytorch/test_recipe.py

@@ -177,7 +177,6 @@ class TestFP8Recipe:
     fp8_format = transformer_engine.common.recipe.Format.HYBRID
     recipe = transformer_engine.common.recipe.DelayedScaling(
         margin=margin,
-        interval=1,
         fp8_format=fp8_format,
         amax_history_len=amax_history_len,
         amax_compute_algo=amax_compute_algo,
tests/pytorch/test_sanity.py

@@ -110,32 +110,17 @@ model_configs = {
 }

 fp8_recipes = [
-    None,  # Handles non-FP8 case
-    recipe.MXFP8BlockScaling(),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
-    recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        amax_compute_algo="most_recent",
-    ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        amax_compute_algo="max",
-    ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        amax_compute_algo=custom_amax_compute,
-    ),
-    recipe.DelayedScaling(
-        margin=0,
-        fp8_format=recipe.Format.E4M3,
-        amax_history_len=16,
-        scaling_factor_compute_algo=custom_amax_to_scale,
-    ),
+    None,  # Test non-FP8
+    recipe.MXFP8BlockScaling(),  # Test default
+    recipe.Float8CurrentScaling(),  # Test default
+    recipe.DelayedScaling(),  # Test default
+    recipe.DelayedScaling(  # Test most_recent algo
+        amax_history_len=16,
+        amax_compute_algo="most_recent",
+    ),
+    recipe.DelayedScaling(  # Test custom amax and scale compute algo
+        fp8_format=recipe.Format.E4M3,
+        amax_history_len=16,
+        amax_compute_algo=custom_amax_compute,
+        scaling_factor_compute_algo=custom_amax_to_scale,
+    ),
 ]
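Aside (not part of the diff): each entry in fp8_recipes is consumed the same way by these sanity tests, roughly the pattern sketched below. The layer type and sizes are illustrative only, and a GPU with FP8 support is assumed.

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.DelayedScaling()  # any entry from fp8_recipes above
    layer = te.Linear(128, 128).cuda()
    inp = torch.randn(16, 128, device="cuda")
    # Run the layer under the selected FP8 recipe; enabled=False covers the None entry.
    with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
        out = layer(inp)
    out.sum().backward()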
@@ -567,6 +552,8 @@ def test_sanity_grouped_linear(
         pytest.skip(reason_for_no_fp8)
     if fp8_recipe.mxfp8():
         pytest.skip("Grouped linear does not support MXFP8")
+    if fp8_recipe.float8_current_scaling():
+        pytest.skip("Grouped linear does not support FP8 current scaling")
     if not config.is_fp8_supported():
         pytest.skip("Model config does not support FP8")
transformer_engine/__init__.py

@@ -19,9 +19,4 @@ try:
 except (ImportError, StopIteration) as e:
     pass

-try:
-    import transformer_engine_jax
-except ImportError:
-    pass
-
 __version__ = str(metadata.version("transformer_engine"))
transformer_engine/common/CMakeLists.txt

@@ -233,7 +233,8 @@ if (USE_CUDA)
   # Configure dependencies
   target_link_libraries(transformer_engine PUBLIC
                         CUDA::cublas
-                        CUDA::cudart)
+                        CUDA::cudart
+                        CUDNN::cudnn_all)

   target_include_directories(transformer_engine PRIVATE
                              ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
   target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
transformer_engine/common/gemm/cublaslt_gemm.cu

@@ -771,6 +771,16 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
   }
 }

+#ifndef __HIP_PLATFORM_AMD__
+namespace transformer_engine {
+
+using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
+
+void nvte_cublas_handle_init() { auto _ = cublasHandleManager::Instance().GetHandle(); }
+
+}  // namespace transformer_engine
+#endif
+
 #ifdef __HIP_PLATFORM_AMD__
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,
transformer_engine/common/include/transformer_engine/gemm.h

@@ -140,6 +140,13 @@ constexpr int num_batchgemm_streams = 1;
 constexpr int num_streams = 4;
 #endif

+/*! \brief TE/JAX cudaGraph requires the cuBLAS initialization to happen outside of the capturing
+ *  region. This function is a helper to call cublasCreate() which allocate memory for the handle.
+ *  The function will be called in the initialize phase of the related XLA custom calls.
+ */
+void nvte_cublas_handle_init();
+
 }  // namespace transformer_engine

 #endif  // TRANSFORMER_ENGINE_GEMM_H_
transformer_engine/common/include/transformer_engine/normalization.h

@@ -149,6 +149,8 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
 void nvte_enable_cudnn_norm_fwd(bool enable);
 void nvte_enable_cudnn_norm_bwd(bool enable);

+enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/include/transformer_engine/transformer_engine.h

@@ -80,7 +80,8 @@ enum NVTEScalingMode {
   /*! Single scale per block of 32 elements consecutive in either
      rowwise or columnwise direction */
   NVTE_MXFP8_1D_SCALING = 1,
-  NVTE_INVALID_SCALING
+  NVTE_INVALID_SCALING = 2,
+  NVTE_NO_SCALING = 3
 };

 /*! \brief TE Tensor type

@@ -346,6 +347,13 @@ enum class DType {
   kNumTypes
 };

+/*! \brief Check if TE datatype is FP8
+ *
+ *  Return true if TE datatype is FP8
+ *  \param[in] DType TE Datatype of interest
+ */
+bool is_fp8_dtype(const DType t);
+
 /*! \struct TensorWrapper
  *  \brief C++ wrapper for the NVTETensor class.
  */
transformer_engine/common/libtransformer_engine.version

@@ -11,10 +11,12 @@
     transformer_engine::ubuf_built_with_mpi*;
     *transformer_engine::rtc*;
     transformer_engine::nvte_cudnn_handle_init*;
+    transformer_engine::nvte_cublas_handle_init*;
     transformer_engine::typeToSize*;
+    transformer_engine::is_fp8_dtype*;
     *transformer_engine::CommOverlapBase*;
     *transformer_engine::CommOverlapP2PBase*;
     *transformer_engine::CommOverlapCore*
   };
   local: *;
-};
\ No newline at end of file
+};
transformer_engine/common/normalization/common.h

@@ -12,6 +12,7 @@
 #include <cudnn_frontend.h>
 #include <cudnn_frontend_utils.h>
 #endif

+#include <transformer_engine/normalization.h>
 #include <transformer_engine/transformer_engine.h>

 #include <functional>

@@ -141,7 +142,6 @@ struct BackwardKernelParams : public KernelParamsBase {
 };

 enum class NVTE_Norm_Backend { Te, Cudnn };
-enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
 enum class NVTE_Norm_Stage { Forward, Backward };

 using TupleKeyType = std::tuple<uint64_t, uint64_t, uint64_t, bool>;
transformer_engine/common/recipe/__init__.py

@@ -162,7 +162,6 @@ class DelayedScaling(Recipe):
     """

     margin: int = 0
-    interval: int = -1
     fp8_format: Format = Format.HYBRID
     amax_history_len: int = 1024
     amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max"

@@ -173,12 +172,6 @@ class DelayedScaling(Recipe):
     def __post_init__(self) -> None:
         assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
-        if self.interval >= 0:
-            warnings.warn(
-                "`interval` argument is deprecated and unused. "
-                "It will be removed in an upcoming release.",
-                DeprecationWarning,
-            )

     def __repr__(self) -> str:
         return (