OpenDAS / TransformerEngine

Commit a8d19fd9, authored May 08, 2025 by yuguo

Merge branch 'main' of http://10.6.10.68/dcutoolkit/deeplearing/TransformerEngine

Parents: 9d0f1c9b, 6dfe66e9
Showing 2 changed files with 13 additions and 2 deletions (+13 -2):

tests/pytorch/test_numerics.py (+11 -1)
transformer_engine/common/gemm/rocm_gemm.cu (+2 -1)
tests/pytorch/test_numerics.py (view file @ a8d19fd9)
@@ -1573,6 +1573,10 @@ def test_grouped_linear_accuracy(
         weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
         sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
 
+    # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
+
     outputs_ref = _test_grouped_linear_accuracy(
         sequential_linear,
         num_gemms,
         bs,
         dtype,
         config,
         recipe,
         fp8,
         fuse_wgrad_accumulation
     )
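The added guard pins the GEMM backend for this accuracy test on ROCm builds: per the in-diff comment, NVTE_FORCE_ROCM_GEMM=1 forces hipblaslt rather than hipblas. A minimal sketch of the same pattern in isolation; force_rocm_gemm_backend is a hypothetical helper, and the assumption that the variable must be set before the first GEMM dispatch is mine, not something the diff states:

import os

def force_rocm_gemm_backend(is_hip_extension: bool) -> None:
    """Hypothetical helper mirroring the test-suite pattern above."""
    if is_hip_extension:
        # Per the diff's comment, "1" selects hipblaslt rather than hipblas.
        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"

# Example: pretend we are on a ROCm build.
force_rocm_gemm_backend(is_hip_extension=True)
print(os.environ.get("NVTE_FORCE_ROCM_GEMM"))  # -> "1"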
@@ -2087,7 +2091,9 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
 @pytest.mark.parametrize("is_paged", [False, True])
 def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged):
     reset_rng_states()
+    if backend in ["FusedAttention"]:
+        pytest.skip("Not support FusedAttention")
     if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
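Note the ordering: the new FusedAttention skip runs before the existing FP32 guard, so FusedAttention cases never reach the dtype check. A self-contained pytest sketch of that ordering; the backend names and parameter grid here are stand-ins for the real fixtures, not the suite's actual parametrization:

import pytest
import torch

@pytest.mark.parametrize("is_paged", [False, True])
@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
def test_skip_ordering(dtype, backend, is_paged):
    # Added guard: FusedAttention is skipped unconditionally, so it
    # never falls through to the FP32 check below.
    if backend in ["FusedAttention"]:
        pytest.skip("Not support FusedAttention")
    # Existing guard: only FlashAttention can still trigger this branch.
    if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32:
        pytest.skip("FusedAttention and FlashAttention do not support FP32")
    # The real test would exercise the KV cache here.
    assert dtype in (torch.float16, torch.float32)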
@@ -2268,6 +2274,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     grad = True
     single_output = False
 
+    # Force the sequential_linear and grouped_linear to use hipblaslt rather than hipblas
+    if IS_HIP_EXTENSION:
+        os.environ["NVTE_FORCE_ROCM_GEMM"] = "1"
+
     for i in range(z):
         general_gemm(
             A[i],
             ...
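For reference, the loop this hunk leads into issues one GEMM per group member; semantically, a grouped GEMM is just z independent matmuls batched into fewer launches. A sketch of that reference behavior with plain torch.matmul standing in for general_gemm, and shapes invented for illustration:

import torch

# Hypothetical shapes: z independent GEMMs, one per group member.
z, m, k, n = 4, 32, 64, 16
A = [torch.randn(m, k) for _ in range(z)]
B = [torch.randn(k, n) for _ in range(z)]

# Reference for the grouped GEMM loop in the diff: each iteration computes
# one ordinary matmul; a grouped kernel batches these into fewer launches.
D = [torch.matmul(A[i], B[i]) for i in range(z)]
assert all(d.shape == (m, n) for d in D)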
transformer_engine/common/gemm/rocm_gemm.cu (view file @ a8d19fd9)
@@ -1459,7 +1459,7 @@ void rocblas_gemm(const Tensor *inputA,
   // extract the stream order alloc env
   bool stream_order_alloc = false;
   if (const char *env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
-    if (env_p != nullptr && std::string(env_p) == "1")
+    if (env_p == nullptr || std::string(env_p) == "1")
       stream_order_alloc = true;
   }
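One observation on this change: both the old and new inner conditions sit inside `if (const char *env_p = std::getenv(...))`, which is entered only when env_p is non-null, so the nullptr clauses on either side are unreachable there. A small Python sketch of the two predicates, with None modeling an unset variable, shows they agree on every reachable input:

def old_pred(env_p):
    # Old condition: env_p != nullptr && std::string(env_p) == "1"
    return env_p is not None and env_p == "1"

def new_pred(env_p):
    # New condition: env_p == nullptr || std::string(env_p) == "1"
    return env_p is None or env_p == "1"

# For any non-None value (the only case the enclosing getenv guard admits),
# the two predicates agree; they differ only on the unreachable None case.
for env_p in ["1", "0", "stream"]:
    assert old_pred(env_p) == new_pred(env_p)
assert old_pred(None) != new_pred(None)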
@@ -1467,6 +1467,7 @@ void rocblas_gemm(const Tensor *inputA,
   NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r && D_type == rocblas_datatype_f16_r) ||
+             (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r && D_type == rocblas_datatype_f32_r) ||
              (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r && D_type == rocblas_datatype_bf16_r) ||
              (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r && D_type == rocblas_datatype_f32_r) ||
              (A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r && D_type == rocblas_datatype_f32_r) ||
              ...
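The one added clause widens the accepted (A_type, B_type, D_type) combinations with fp16 inputs producing an fp32 output. A hedged sketch of the resulting whitelist, with string tags modeling the rocblas_datatype_* enums; the trailing `||` in the diff implies further combinations beyond this excerpt:

# String tags model rocblas_datatype_f16_r / _bf16_r / _f32_r.
ALLOWED = {
    ("f16", "f16", "f16"),
    ("f16", "f16", "f32"),   # the combination added by this commit
    ("bf16", "bf16", "bf16"),
    ("bf16", "bf16", "f32"),
    ("f32", "f32", "f32"),
    # ... the check continues past the excerpt shown in the diff
}

def check_types(a: str, b: str, d: str) -> bool:
    """Mirrors the NVTE_CHECK whitelist for the combinations visible above."""
    return (a, b, d) in ALLOWED

assert check_types("f16", "f16", "f32")   # newly accepted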