Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
196a213f
Commit
196a213f
authored
May 20, 2025
by
yuguo
Browse files
[DCU] variable ub streams add NVTE_UB_STREAM_NUMS
parent
1312aa6e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
82 additions
and
32 deletions
+82
-32
tests/pytorch/distributed/run_layer_with_overlap.py
tests/pytorch/distributed/run_layer_with_overlap.py
+1
-1
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+76
-25
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
...ine/common/include/transformer_engine/comm_gemm_overlap.h
+2
-4
transformer_engine/pytorch/module/base.py
transformer_engine/pytorch/module/base.py
+3
-2
No files found.
tests/pytorch/distributed/run_layer_with_overlap.py
View file @
196a213f
...
...
@@ -496,7 +496,7 @@ def _train(opts):
if
opts
.
benchmark
:
# Warmup to not profile CPU overhead
for
_
in
range
(
20
):
for
_
in
range
(
opts
.
benchmark_iter
):
if
opts
.
use_cuda_graphs
:
test_graph
.
replay
()
else
:
...
...
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
196a213f
...
...
@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority
=
gemm_priority
;
_comm_priority
=
comm_priority
;
}
static
cudaStream_t
compute_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
num_splits
);
i
++
)
{
cudaStream_t
stream
;
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
stream
,
cudaStreamNonBlocking
,
_gemm_priority
));
_stream_compute
.
push_back
(
std
::
move
(
stream
));
if
(
compute_streams
[
i
]
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
_gemm_priority
));
}
_stream_compute
.
push_back
(
compute_streams
[
i
]);
}
_num_splits
=
num_splits
;
...
...
@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle
,
barrier_handle
,
num_splits
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
false
,
atomic_gemm
)
{
_ub_stream_nums
=
num_max_streams
;
_rs_overlap_first_gemm
=
rs_overlap_first_gemm
;
_rs_kernel_type
=
getenv
<
int
>
(
"NVTE_RS_STRIDED_ATOMIC"
,
0
);
NVTE_CHECK
(
_rs_kernel_type
>=
0
&&
_rs_kernel_type
<=
3
,
...
...
@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if
(
_ub_comm
->
myrank
==
0
)
printf
(
"!!! [UB] Register UBuf %d
\n
"
,
_ub_reg
);
_ubuf
=
TensorWrapper
(
buffer_ptr
,
buffer_shape
,
buffer_dtype
);
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
_stream_comm
,
cudaStreamNonBlocking
,
_comm_priority
));
static
cudaStream_t
comm_stream
;
if
(
comm_stream
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
comm_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
}
_stream_comm
=
comm_stream
;
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_start_d2dcopy
,
0
));
}
...
...
@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_compute
[
0
]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaDeviceSynchronize
());
}
// CommOverlapBase::bulk_overlap
...
...
@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto
output_chunk
=
get_buffer_chunk_like
(
D
,
0
,
{
m
,
m_chunk
});
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
0
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]);
}
else
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
],
1
,
0
,
0
);
}
for
(
int
i
=
1
;
i
<
_num_splits
;
i
++
)
{
input_a_chunk
=
get_tensor_chunk
(
A
,
i
*
input_a_chunk_size
,
{
m_chunk
,
k
});
...
...
@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
}
else
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
_stream_compute
[(
i
-
1
)
%
_stream_compute
.
size
()]));
...
...
@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
}
else
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_start_comm
,
_stream_compute
[
i
%
_stream_compute
.
size
()]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
_stream_comm
,
_start_comm
,
0
));
...
...
@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_comm
,
_stream_comm
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_comm
,
0
));
NVTE_CHECK_CUDA
(
cudaDeviceSynchronize
());
}
// CommOverlapBase::split_overlap_rs
/***************************************************************************************************
...
...
@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle
,
barrier_handle
,
tp_size
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
use_ce
,
atomic_gemm
)
{
_ub_stream_nums
=
num_max_streams
;
_is_p2p
=
true
;
_is_reduce_scatter
=
comm_type
==
CommOverlapType
::
RS
;
_aggregate
=
aggregate
;
...
...
@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA
(
cudaMemset
(
_counter
.
dptr
(),
0
,
sizeof
(
int32_t
)));
}
static
cudaStream_t
send_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
static
cudaStream_t
recv_stream
;
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
_tp_size
);
i
++
)
{
cudaStream_t
stream
;
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
stream
,
cudaStreamNonBlocking
,
_comm_priority
));
_stream_send
.
push_back
(
std
::
move
(
stream
));
if
(
send_streams
[
i
]
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
send_streams
[
i
],
cudaStreamNonBlocking
,
_comm_priority
));
}
_stream_send
.
push_back
(
send_streams
[
i
]);
}
if
(
recv_stream
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
recv_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
}
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
_stream_recv
,
cudaStreamNonBlocking
,
_comm_priority
));
_stream_recv
=
recv_stream
;
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_stop_send
,
0
));
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_stop_recv
,
0
));
}
...
...
@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
}
else
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
}
if
(
i
<
num_steps
-
1
)
{
// P2P communication
...
...
@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()]);
}
else
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
i
%
_stream_compute
.
size
()],
1
,
0
,
i
%
_stream_compute
.
size
());
}
if
(
i
<
_tp_size
-
1
)
{
// P2P communication
...
...
@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_send
,
0
));
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_recv
,
_stream_recv
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_recv
,
0
));
NVTE_CHECK_CUDA
(
cudaDeviceSynchronize
());
}
// CommOverlapP2PBase::split_overlap_ag
/*
...
...
@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
stream_id
*
workspace_size_chunk
,
{
workspace_size_chunk
});
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
],
1
,
0
,
stream_id
);
if
(
_ub_stream_nums
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
]);
}
else
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
],
1
,
0
,
stream_id
);
}
if
(
i
>
0
)
{
// P2P communication chunk
...
...
@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
}
NVTE_CHECK_CUDA
(
cudaEventRecord
(
_stop_recv
,
_stream_recv
));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
stream_main
,
_stop_recv
,
0
));
NVTE_CHECK_CUDA
(
cudaDeviceSynchronize
());
// Reduce GEMM output chunks
char
*
reduce_buf_ptr
=
reinterpret_cast
<
char
*>
(
_ubufs
[
_tp_size
-
1
].
dptr
());
...
...
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
View file @
196a213f
...
...
@@ -15,11 +15,7 @@
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif
namespace
transformer_engine
{
...
...
@@ -141,6 +137,7 @@ class CommOverlapCore {
class
CommOverlapBase
:
public
CommOverlapCore
{
protected:
int
_ub_stream_nums
;
int
_rs_kernel_type
;
bool
_rs_overlap_first_gemm
;
cudaStream_t
_stream_comm
;
...
...
@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
class
CommOverlapP2PBase
:
public
CommOverlapCore
{
protected:
int
_ub_stream_nums
;
bool
_is_reduce_scatter
{
false
};
bool
_use_multiatomic_ag
{
false
};
bool
_aggregate
;
...
...
transformer_engine/pytorch/module/base.py
View file @
196a213f
...
...
@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
_2X_ACC_WGRAD
=
True
_multi_stream_cublas_workspace
=
[]
_dummy_wgrads
=
{}
multi_stream_cublas_batchgemm_workspace
=
[]
_
multi_stream_cublas_batchgemm_workspace
=
[]
_cublas_workspace
=
None
_ub_communicators
=
None
_NUM_MAX_UB_STREAMS
=
2
if
IS_HIP_EXTENSION
else
3
ub_stream_nums
=
int
(
os
.
getenv
(
"NVTE_UB_STREAM_NUMS"
,
"2"
))
_NUM_MAX_UB_STREAMS
=
ub_stream_nums
if
IS_HIP_EXTENSION
else
3
_MIN_STREAM_PRIORITY
,
_MAX_STREAM_PRIORITY
=
None
,
None
layers_atomic_ring_exchange
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment