OpenDAS / TransformerEngine · Commits · 86f2e9a9

Commit 86f2e9a9 authored Apr 29, 2025 by wenjh
[PytorchUnitTest] Fix errors while running tests
Signed-off-by: wenjh <wenjh@sugon.com>
parent 11b6b7e4
Changes: 4 changed files with 18 additions and 6 deletions (+18 -6)
tests/pytorch/test_numerics.py                            +5  -2
transformer_engine/common/permutation/permutation.cu      +8  -0
transformer_engine/pytorch/csrc/extensions/attention.cu   +2  -1
transformer_engine/pytorch/triton/cross_entropy.py        +3  -3
tests/pytorch/test_numerics.py

@@ -39,6 +39,7 @@ from transformer_engine.pytorch import (
     Fp8Padding,
     Fp8Unpadding,
 )
+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
@@ -61,8 +62,10 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
-torch._dynamo.config.recompile_limit = 16
+if torch_version() >= (2, 7, 0):
+    torch._dynamo.config.recompile_limit = 16
+else:
+    torch._dynamo.config.cache_size_limit = 16


 class ModelConfig:
     def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
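The new guard selects the Dynamo configuration knob by PyTorch version: newer releases expose recompile_limit, while older ones only accept the legacy cache_size_limit name, so the previous unconditional assignment presumably raised an error on older builds. A minimal standalone sketch of the same pattern, using a hypothetical torch_version_tuple() helper in place of transformer_engine.pytorch.torch_version():

import re

import torch

def torch_version_tuple() -> tuple:
    # Hypothetical stand-in for transformer_engine.pytorch.torch_version():
    # keep only leading numeric components, e.g. "2.7.0a0+git1234" -> (2, 7, 0).
    return tuple(int(x) for x in re.findall(r"\d+", torch.__version__)[:3])

def set_dynamo_recompile_limit(limit: int = 16) -> None:
    if torch_version_tuple() >= (2, 7, 0):
        torch._dynamo.config.recompile_limit = limit  # current name
    else:
        torch._dynamo.config.cache_size_limit = limit  # pre-2.7 name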
transformer_engine/common/permutation/permutation.cu

@@ -253,7 +253,11 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
                        num_out_tokens);
     blocks = num_rows;
+#ifdef __HIP_PLATFORM_AMD__
+    threads = std::min(num_cols / kElementsPerAccess, 256);
+#else
     threads = std::min(num_cols / kElementsPerAccess, 1024);
+#endif
     moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
         input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
   } else {
@@ -305,7 +309,11 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f
   static constexpr int kElementsPerAccess = 16 / sizeof(T);
   int blocks = num_rows;
+#ifdef __HIP_PLATFORM_AMD__
+  int threads = std::min(num_cols / kElementsPerAccess, 256);
+#else
   int threads = std::min(num_cols / kElementsPerAccess, 1024);
+#endif
   size_t smem_bytes = topK * sizeof(TCompute);
   if (prob == nullptr) {
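Both launchers keep the same per-thread workload (one 16-byte vectorized access, i.e. 16 / sizeof(T) elements) but cap the block at 256 threads under __HIP_PLATFORM_AMD__ instead of 1024, presumably to fit per-block resource limits or occupancy behavior on AMD GPUs. A small Python restatement of the launch arithmetic (names here are illustrative):

def launch_threads(num_cols: int, elem_bytes: int, is_hip: bool) -> int:
    # Each thread covers one 16-byte access, i.e. 16 / sizeof(T) elements;
    # the block size is clamped lower on ROCm than on CUDA.
    k_elements_per_access = 16 // elem_bytes
    cap = 256 if is_hip else 1024
    return min(num_cols // k_elements_per_access, cap)

# fp16 (2-byte) rows with 4096 columns need 512 threads; that fits under
# the CUDA cap but is clamped to 256 on ROCm.
assert launch_threads(4096, 2, is_hip=False) == 512
assert launch_threads(4096, 2, is_hip=True) == 256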
transformer_engine/pytorch/csrc/extensions/attention.cu

@@ -18,7 +18,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
     int64_t window_size_left, int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
-  assert(false);
+  // assert(false);
+  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
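Instead of tripping assert(false) on ROCm, which aborts the whole process whenever the tests probe for a fused-attention backend, the query now reports NVTE_No_Backend and lets the caller degrade gracefully. A hedged sketch of caller-side handling (the enum value and fallback below are illustrative, not the actual binding):

import torch
import torch.nn.functional as F

NVTE_NO_BACKEND = 0  # illustrative stand-in for NVTE_Fused_Attn_Backend::NVTE_No_Backend

def attention_with_fallback(q, k, v, fused_backend: int):
    # With this change, a ROCm build reports "no backend" rather than
    # aborting, so the Python layer can take an unfused path instead.
    if fused_backend == NVTE_NO_BACKEND:
        return F.scaled_dot_product_attention(q, k, v)  # plain PyTorch fallback
    raise NotImplementedError("dispatch to the fused TE kernel here")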
transformer_engine/pytorch/triton/cross_entropy.py

@@ -281,7 +281,7 @@ def cross_entropy_forward(
         rank=rank,
         n_cols=V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)
@@ -309,7 +309,7 @@ def cross_entropy_forward(
         n_non_ignore=n_rows,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
@@ -335,7 +335,7 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
         grad_output,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     return _input
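All three Triton launch sites halve num_warps on ROCm. A plausible rationale (an assumption here, not stated in the commit): AMD wavefronts are 64 lanes wide versus 32-lane NVIDIA warps, so 16 warps on ROCm occupy the same 1024 threads per block as 32 warps on CUDA. The selection in isolation, assuming IS_HIP_EXTENSION comes from torch.utils.cpp_extension as is common in ROCm-aware code:

from torch.utils.cpp_extension import IS_HIP_EXTENSION  # True on ROCm builds

def pick_num_warps(is_hip: bool = IS_HIP_EXTENSION) -> int:
    # Assumption: 16 warps x 64-lane wavefront == 32 warps x 32-lane warp
    # == 1024 threads per block, keeping the launch shape identical.
    return 16 if is_hip else 32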