Commit 196a213f authored by yuguo
Browse files

[DCU] Make the number of UB streams configurable via NVTE_UB_STREAM_NUMS

parent 1312aa6e
...@@ -496,7 +496,7 @@ def _train(opts): ...@@ -496,7 +496,7 @@ def _train(opts):
if opts.benchmark: if opts.benchmark:
# Warmup to not profile CPU overhead # Warmup to not profile CPU overhead
for _ in range(20): for _ in range(opts.benchmark_iter):
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
test_graph.replay() test_graph.replay()
else: else:
......
...@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl ...@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority = gemm_priority; _gemm_priority = gemm_priority;
_comm_priority = comm_priority; _comm_priority = comm_priority;
} }
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream; if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
_stream_compute.push_back(std::move(stream)); }
_stream_compute.push_back(compute_streams[i]);
} }
_num_splits = num_splits; _num_splits = num_splits;
...@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) { atomic_gemm) {
_ub_stream_nums = num_max_streams;
_rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0); _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
...@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA( static cudaStream_t comm_stream;
cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); if (comm_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
}
_stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
} }
...@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te ...@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0])); NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::bulk_overlap } // CommOverlapBase::bulk_overlap
...@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0); use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
}
for (int i = 1; i < _num_splits; i++) { for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
...@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk = get_tensor_chunk( workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
accumulate, use_split_accumulator, _math_sms, pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size()); accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA( NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
...@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size()); _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()])); NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
...@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::split_overlap_rs } // CommOverlapBase::split_overlap_rs
/*************************************************************************************************** /***************************************************************************************************
...@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) { atomic_gemm) {
_ub_stream_nums = num_max_streams;
_is_p2p = true; _is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS; _is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate; _aggregate = aggregate;
...@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, ...@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
} }
static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
static cudaStream_t recv_stream;
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
cudaStream_t stream; if (send_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream)); }
_stream_send.push_back(send_streams[i]);
}
if (recv_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&recv_stream, cudaStreamNonBlocking, _comm_priority));
} }
NVTE_CHECK_CUDA( _stream_recv = recv_stream;
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
} }
...@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size()); _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < num_steps - 1) { if (i < num_steps - 1) {
// P2P communication // P2P communication
...@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk( auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size()); _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < _tp_size - 1) { if (i < _tp_size - 1) {
// P2P communication // P2P communication
...@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, ...@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapP2PBase::split_overlap_ag } // CommOverlapP2PBase::split_overlap_ag
/* /*
...@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, ...@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto workspace_chunk = auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), if (_ub_stream_nums == 1) {
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id); pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
}
if (i > 0) { if (i > 0) {
// P2P communication chunk // P2P communication chunk
...@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, ...@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
} }
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
// Reduce GEMM output chunks // Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
......
...@@ -15,11 +15,7 @@ ...@@ -15,11 +15,7 @@
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" #include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3 #define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif
namespace transformer_engine { namespace transformer_engine {
...@@ -141,6 +137,7 @@ class CommOverlapCore { ...@@ -141,6 +137,7 @@ class CommOverlapCore {
class CommOverlapBase : public CommOverlapCore { class CommOverlapBase : public CommOverlapCore {
protected: protected:
int _ub_stream_nums;
int _rs_kernel_type; int _rs_kernel_type;
bool _rs_overlap_first_gemm; bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm; cudaStream_t _stream_comm;
...@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore { ...@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
class CommOverlapP2PBase : public CommOverlapCore { class CommOverlapP2PBase : public CommOverlapCore {
protected: protected:
int _ub_stream_nums;
bool _is_reduce_scatter{false}; bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false}; bool _use_multiatomic_ag{false};
bool _aggregate; bool _aggregate;
......
...@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True ...@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True _2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = [] _multi_stream_cublas_workspace = []
_dummy_wgrads = {} _dummy_wgrads = {}
multi_stream_cublas_batchgemm_workspace = [] _multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None _cublas_workspace = None
_ub_communicators = None _ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3 ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment