OpenDAS / TransformerEngine / Commits / 5cc8ee3e

Commit 5cc8ee3e authored Oct 30, 2025 by zhaochao

[DCU] fix some bug

Signed-off-by: zhaochao <zhaochao1@sugon.com>
parent 183a88cf

Showing 5 changed files with 73 additions and 33 deletions (+73 −33)
qa/L0_pytorch_unittest/test.sh                                           +1   −0
tests/pytorch/attention/test_attention.py                               +15  −17
tests/pytorch/test_numerics.py                                          +56   −7
transformer_engine/pytorch/attention/dot_product_attention/backends.py   +0   −8
transformer_engine/pytorch/module/_common.py                             +1   −1
qa/L0_pytorch_unittest/test.sh

@@ -51,6 +51,7 @@ NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
+mkdir -p $TE_PATH/artifacts/tests/pytorch/test_checkpoint && python $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all --checkpoint-dir $TE_PATH/artifacts/tests/pytorch/test_checkpoint/
 NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
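The one-line addition pre-generates checkpoint artifacts (`--save-checkpoint all`) before `test_checkpoint.py` runs under pytest with `NVTE_TEST_CHECKPOINT_ARTIFACT_PATH` pointing at the same directory, so the load-side assertions have files to read. A minimal sketch of how a test could consume that variable — the helper and the directory layout are assumptions for illustration, not code from this commit:

```python
import os
import pytest

def checkpoint_artifact_dir() -> str:
    # NVTE_TEST_CHECKPOINT_ARTIFACT_PATH is set by qa/L0_pytorch_unittest/test.sh;
    # skip (rather than fail) when the save step has not been run, e.g. when the
    # test file is invoked standalone.
    path = os.environ.get("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH", "")
    if not os.path.isdir(path):
        pytest.skip("checkpoint artifacts missing; run the --save-checkpoint step first")
    return path
```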
tests/pytorch/attention/test_attention.py

@@ -216,9 +216,6 @@ def test_dot_product_attention(
     # FlashAttention backend
     if flash_attn_supported:
-        from torch.utils.cpp_extension import IS_HIP_EXTENSION
-        if IS_HIP_EXTENSION and config.head_dim_qk < config.head_dim_v:
-            pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
         flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
             dtype,
             config,
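This removes the runtime ROCm guard; the asymmetric-head-dim cases it skipped are instead disabled statically in `model_configs_mla` in the next hunk. For reference, the pattern the deleted lines used, pulled out as a standalone helper (the bare head-dim arguments stand in for the test's `config` object):

```python
import pytest
from torch.utils.cpp_extension import IS_HIP_EXTENSION  # True on ROCm/HIP builds of PyTorch

def maybe_skip_rocm_mla(head_dim_qk: int, head_dim_v: int) -> None:
    # FlashAttention on ROCm in this fork cannot produce a wider output head dim,
    # so MLA shapes with head_dim_qk < head_dim_v have to be skipped.
    if IS_HIP_EXTENSION and head_dim_qk < head_dim_v:
        pytest.skip("FlashAttention on ROCm does not support MLA with head_dim_qk < head_dim_v")
```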
@@ -263,21 +260,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
 model_configs_mla = {
+    #TODO:FlashAttention on ROCm only support MLA with head_dim_qk = head_dim_v
     # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
-    "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128),  # self , 0
+    # "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0
-    "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    # "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0
-    "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128),  # cross, 0
+    # "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0
-    "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64),  # self , 1
+    # "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1
-    "mla_2_1": ModelConfig(
-        1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
-    ),  # cross, 1
+    # "mla_2_1": ModelConfig(
+    #     1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64
+    # ), # cross, 1
-    "mla_2_2": ModelConfig(
-        1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
-    ),  # cross, 1
+    # "mla_2_2": ModelConfig(
+    #     1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128
+    # ), # cross, 1
-    "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64),  # inference
+    # "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference
-    "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference
-    "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference
-    "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128),  # inference
+    # "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference
     "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160),  # inference
 }
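Every entry commented out above has `head_dim_qk != head_dim_v`; only `mla_3_4` (160/160) stays active, matching the TODO. The shape semantics being avoided are easy to reproduce with stock PyTorch SDPA, which does allow an asymmetric value head dim (an illustration only, unrelated to the TE kernels):

```python
import torch
import torch.nn.functional as F

# MLA-style attention: Q and K share head_dim_qk; V (and the output) use head_dim_v.
b, h, sq, skv, dqk, dv = 2, 16, 128, 256, 192, 128
q = torch.randn(b, h, sq, dqk)
k = torch.randn(b, h, skv, dqk)
v = torch.randn(b, h, skv, dv)

out = F.scaled_dot_product_attention(q, k, v)  # softmax(Q @ K^T / sqrt(dqk)) @ V
assert out.shape == (b, h, sq, dv)  # the output inherits the value head dim
```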
tests/pytorch/test_numerics.py

@@ -50,8 +50,8 @@ from utils import ModelConfig, reset_rng_states, get_available_attention_backend
 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
-fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
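The underscores previously threw away the reason strings; capturing them lets the skip guards added throughout this file report why a recipe is unavailable instead of skipping silently. The pattern in miniature (the probe is a stub standing in for `FP8GlobalStateManager.is_mxfp8_available()`, which, as the diff shows, returns an `(available, reason)` pair):

```python
import pytest

def is_mxfp8_available_stub():
    # Stand-in for FP8GlobalStateManager.is_mxfp8_available(): (bool, reason).
    return False, "MXFP8 requires a newer device than the one detected"

mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available_stub()

def test_needs_mxfp8():
    if not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)  # the skip message now carries the cause
```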
@@ -582,6 +582,12 @@ def _test_e2e_selective_recompute(
 def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     config = model_configs[model]
@@ -690,8 +696,14 @@ def _test_e2e_full_recompute(
 def test_gpt_full_activation_recompute(
     dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
@@ -1277,10 +1289,14 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
     fuse_wgrad_accumulation = True
     fp8_model_params = False
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
         pytest.skip("FP8 requires sequence length to be divisible by 16.")
@@ -1793,6 +1809,12 @@ def test_grouped_linear_accuracy(
     parallel_mode=None,
 ):
     fp8 = recipe is not None
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
@@ -1837,8 +1859,9 @@ def test_grouped_linear_accuracy(
         if fuse_wgrad_accumulation:
             weight_i = getattr(grouped_linear, f"weight{i}")
             weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
             sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,
@@ -1861,7 +1884,8 @@ def test_grouped_linear_accuracy(
         fuse_wgrad_accumulation,
         delay_wgrad_compute,
     )
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "0"
     # Shoule be bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
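These two hunks bracket the reference computation with `NVTE_FORCE_ROCM_GEMM=1` on HIP builds, so both the grouped and the sequential paths hit the same ROCm GEMM and the `rtol=0, atol=0` comparison can demand bit-wise equality. A set/reset pair like this leaks if an exception fires between the hunks; a context-manager variant (a sketch of an alternative, not what the commit does) restores the variable unconditionally:

```python
import os
from contextlib import contextmanager

@contextmanager
def env_var(name: str, value: str):
    """Temporarily set an environment variable and restore it on exit."""
    old = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = old

# Usage sketch:
# with env_var("NVTE_FORCE_ROCM_GEMM", "1"):
#     outputs_ref = _test_grouped_linear_accuracy(sequential_linear, ...)
```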
@@ -1893,6 +1917,12 @@ def test_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if fp8 and recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2099,8 +2129,14 @@ def test_padding_grouped_linear_accuracy(
     fp8_model_params,
     parallel_mode=None,
 ):
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2172,6 +2208,13 @@ def test_padding_grouped_linear_accuracy_save_original_input(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if fp8 and recipe.delayed():
         pytest.skip("DelayedScaling recipe is not supported with save_original_input")
+    if fp8 and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
     if config.max_seqlen_q % 16 != 0 and fp8:
@@ -2383,8 +2426,14 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
 @pytest.mark.parametrize("model", ["126m"])
 @pytest.mark.parametrize("recipe", fp8_recipes)
 def test_gpt_fp8_parameters(dtype, bs, model, recipe):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if recipe.mxfp8() and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
     if NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
     config = model_configs[model]
transformer_engine/pytorch/attention/dot_product_attention/backends.py

@@ -890,14 +890,6 @@ class FlashAttention(torch.nn.Module):
         elif q_format == "thd":
             # thd -> t(hd)
             output = output.reshape(output.shape[0], -1)
-        # Handle output shape when V head dim differs from Q/K head dim
-        if value_layer.shape[-1] != query_layer.shape[-1]:
-            v_dim = value_layer.shape[-1]
-            num_heads = query_layer.shape[-2]
-            out_shape_heads = output.shape[:-1] + (num_heads, query_layer.shape[-1])
-            output = output.view(out_shape_heads)[..., :v_dim]
-            output = output.reshape(output.shape[:-2] + (num_heads * v_dim,))
-
         return output.contiguous()
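With the asymmetric MLA configs disabled in the tests above, this output-trimming block became dead code and is removed. What it did: treat each head of the flattened flash-attention output as padded out to the Q/K head dim, slice it back down to the V head dim, and re-flatten. A standalone numeric replay on dummy tensors (shapes chosen arbitrarily; the padding assumption is my reading of the deleted code, not stated in the commit):

```python
import torch

# Dummy flash-attention output in t(hd) layout, each head padded to head_dim_qk.
t, num_heads, head_dim_qk, head_dim_v = 4, 2, 192, 128
output = torch.randn(t, num_heads * head_dim_qk)

# Equivalent of the deleted block: un-flatten the heads, keep only the first
# head_dim_v channels of each head, then flatten back to t(h * head_dim_v).
out_shape_heads = output.shape[:-1] + (num_heads, head_dim_qk)
output = output.view(out_shape_heads)[..., :head_dim_v]
output = output.reshape(output.shape[:-2] + (num_heads * head_dim_v,))
assert output.shape == (t, num_heads * head_dim_v)
```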
transformer_engine/pytorch/module/_common.py

@@ -53,7 +53,7 @@ def apply_normalization(
     normalization_func = _get_normalization_func(normalization, True)
     inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
-    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32):
+    if enable_lightop and (ln_bias is None) and normalization == "RMSNorm" and output_quantizer is None and (output_dtype is torch.bfloat16 or output_dtype is torch.float16 or output_dtype is torch.float32) and not zero_centered_gamma:
         out, rsigma = rmsnorm_forward(inputmat, ln_weight, ln_out, eps, True)
         return out, None, rsigma
     else:
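The fix gates the lightop RMSNorm fast path on `not zero_centered_gamma`: under the zero-centered convention the effective scale is `1 + gamma`, not `gamma`, and a kernel unaware of that silently produces wrong activations. A plain-PyTorch reference of the two conventions (for illustration only; `rmsnorm_forward` itself is the fork's fused kernel and is not reproduced here):

```python
import torch

def rmsnorm_ref(x, gamma, eps, zero_centered_gamma):
    # rsigma = 1 / RMS(x), computed per row over the hidden dimension.
    rsigma = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    weight = gamma + 1.0 if zero_centered_gamma else gamma
    return x * rsigma * weight

x, g = torch.randn(4, 8), torch.zeros(8)
# With a stored gamma of all zeros, zero-centered means "scale by 1",
# while the plain convention would zero the output entirely.
assert not torch.allclose(rmsnorm_ref(x, g, 1e-5, True), rmsnorm_ref(x, g, 1e-5, False))
```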