OpenDAS / TransformerEngine · Commits

Commit 8fb50d09, authored Apr 14, 2025 by yuguo
Parent: b71ea424

[DCU] tmp fix

Showing 3 changed files with 14 additions and 12 deletions:

  tests/pytorch/distributed/run_gemm_with_overlap.py   +6 -6
  tests/pytorch/distributed/run_layer_with_overlap.py  +6 -4
  transformer_engine/pytorch/module/base.py            +2 -2
tests/pytorch/distributed/run_gemm_with_overlap.py
@@ -312,7 +312,7 @@ def _main(opts):
         helper,
         tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
         opts.comm_type,
-        num_max_streams=2 if IS_HIP_EXTENSION else 3,
+        num_max_streams=1 if IS_HIP_EXTENSION else 3,
         set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
         atomic_gemm=opts.atomic,
         aggregate=opts.aggregate,
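The only functional change in this hunk is the HIP-side stream cap dropping from 2 to 1. A minimal standalone sketch of the pattern, assuming `IS_HIP_EXTENSION` here is the platform flag that PyTorch exposes in `torch.utils.cpp_extension`:

```python
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# After this commit, ROCm/DCU builds cap the comm+GEMM overlap at a single
# communication stream; CUDA builds keep the previous default of 3.
num_max_streams = 1 if IS_HIP_EXTENSION else 3
print(f"num_max_streams = {num_max_streams}")
```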
@@ -401,7 +401,7 @@ def _main(opts):
     )

     # Allocate cuBLAS workspace
-    workspace_size = 2 * get_cublas_workspace_size_bytes()
+    workspace_size = 1 * get_cublas_workspace_size_bytes()
     workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")

     # Gather global tensors and calculate reference result (need these first for Fp8 scales)
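The cuBLAS workspace is likewise halved, from two workspace-sized buffers to one. A hedged, self-contained sketch of the allocation; `get_cublas_workspace_size_bytes` is stubbed here because its real definition lives in the test helpers, and the 32 MiB figure is illustrative only:

```python
import torch

def get_cublas_workspace_size_bytes() -> int:
    # Stub for the real helper; the actual size is architecture-dependent.
    return 32 * 1024 * 1024

# Post-commit: a single workspace-sized uint8 buffer instead of a doubled one.
workspace = torch.empty(
    1 * get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
)
```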
@@ -773,17 +773,17 @@ def _main(opts):
             "NUMERICAL CHECK FAILED: "
             + f"Outputs not close enough at index {m.item()} "
             + f"with {test_out.flatten()[m].item()} vs {ref_out.flatten()[m].item()} | "
-            + f"rel. error = {rel_err} (tol = {rtol}) | "
-            + f"abs. error = {abs_err} (tol = {atol})"
+            + f"rel. deviation = {rel_err} (tol = {rtol}) | "
+            + f"abs. deviation = {abs_err} (tol = {atol})"
         )
     else:
         numerics_info = "NUMERICAL CHECK PASSED: "
         if rel_err <= rtol:
-            numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
+            numerics_info += f"rel. deviation = {rel_err} (tol = {rtol})" + (
                 " | " if abs_err < atol else ""
             )
         if abs_err <= atol:
-            numerics_info += f"abs. error = {abs_err} (tol = {atol})"
+            numerics_info += f"abs. deviation = {abs_err} (tol = {atol})"

     dist_print(
         numerics_info,
         src=0,
         section=True,
         info=True,
         error=numerics_failed,
         group=tp_group,
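The remaining edits in this file only rename "error" to "deviation" in the report strings; the comparison logic itself is untouched. For context, a hypothetical helper showing how the `m`, `abs_err`, and `rel_err` referenced in those messages are conventionally derived (the test's actual definitions may differ):

```python
import torch

def worst_deviation(test_out: torch.Tensor, ref_out: torch.Tensor):
    """Hypothetical sketch: find the worst element and report its absolute
    and relative deviation, as referenced in the messages above."""
    diff = (test_out - ref_out).abs().flatten()
    m = torch.argmax(diff)                       # index of the largest mismatch
    abs_err = diff[m].item()
    denom = max(abs(ref_out.flatten()[m].item()), 1e-12)  # guard zero reference
    rel_err = abs_err / denom
    return m, abs_err, rel_err
```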
tests/pytorch/distributed/run_layer_with_overlap.py
@@ -3,6 +3,8 @@
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
+# UB_SKIPMC=1 mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_layer_with_overlap.py --seed=42 --seq-length=4096 --batch-size=2 --num-heads=96 --head-dim=128 --layer-type LayerNormLinear --linear-parallel-mode column --num-layers 1 --overlap-rs-dgrad
+# NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=1 UB_SKIPMC=1 mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_layer_with_overlap.py --seed=42 --seq-length=4096 --batch-size=2 --num-heads=96 --head-dim=128 --layer-type MultiheadAttention --num-layers 1 --overlap-rs-dgrad

 import os
 import sys
@@ -266,17 +268,17 @@ def _compare_tensors(name, test, ref, rtol, atol):
             "NUMERICAL CHECK FAILED: "
             + f"{name} not close enough at index {m.item()} "
             + f"with {test.flatten()[m].item()} vs {ref.flatten()[m].item()} | "
-            + f"rel. error = {rel_err} (tol = {rtol}) | "
-            + f"abs. error = {abs_err} (tol = {atol})"
+            + f"rel. deviation = {rel_err} (tol = {rtol}) | "
+            + f"abs. deviation = {abs_err} (tol = {atol})"
         )
     else:
         numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
         if rel_err <= rtol:
-            numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + (
+            numerics_info += f"rel. deviation = {rel_err} (tol = {rtol})" + (
                 " | " if abs_err <= atol else "."
             )
         if abs_err <= atol:
-            numerics_info += f" abs. error = {abs_err} (tol = {atol})"
+            numerics_info += f" abs. deviation = {abs_err} (tol = {atol})"

     return numerics_failed, numerics_info
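Same rename here, inside `_compare_tensors`. The separator logic in the PASSED branch is easy to misread, so here is a standalone rendering of it with made-up numbers (values chosen only to exercise both branches):

```python
# Standalone rendering of the PASSED-branch message logic above.
name, rel_err, abs_err, rtol, atol = "output", 1.2e-4, 3.4e-6, 1e-3, 1e-5
numerics_info = f"NUMERICAL CHECK PASSED: {name} | "
if rel_err <= rtol:
    numerics_info += f"rel. deviation = {rel_err} (tol = {rtol})" + (
        " | " if abs_err <= atol else "."
    )
if abs_err <= atol:
    numerics_info += f" abs. deviation = {abs_err} (tol = {atol})"
print(numerics_info)
# -> NUMERICAL CHECK PASSED: output | rel. deviation = 0.00012 (tol = 0.001) |  abs. deviation = 3.4e-06 (tol = 1e-05)
```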
transformer_engine/pytorch/module/base.py
@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
 _multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
+_NUM_MAX_UB_STREAMS = 1 if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
@@ -357,7 +357,7 @@ def initialize_ub(
         helper,  # Helper for torch.distributed callbacks during bootstrapping
         tp_size,  # Tensor-parallel group size (may be different than local_size)
         num_splits=num_splits,
-        num_max_streams=_NUM_MAX_UB_STREAMS - 1 if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS,
+        num_max_streams=_NUM_MAX_UB_STREAMS,
         comm_cga_size=cga_size,
         num_comm_sm=num_sm,
         set_sm_margin=set_sm_margin,
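With `_NUM_MAX_UB_STREAMS` already capped to 1 on HIP by the first hunk in this file, the `- 1` adjustment at this call site becomes redundant and is dropped. A small self-contained check of that equivalence, again assuming `IS_HIP_EXTENSION` comes from `torch.utils.cpp_extension`:

```python
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# Pre-commit: module constant of 2 on HIP, minus 1 applied at the call site.
old_constant = 2 if IS_HIP_EXTENSION else 3
old_call_site = old_constant - 1 if IS_HIP_EXTENSION else old_constant

# Post-commit: module constant of 1 on HIP, passed through unchanged.
new_constant = 1 if IS_HIP_EXTENSION else 3
assert old_call_site == new_constant  # 1 on HIP, 3 on CUDA, either way
```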