Commit 196a213f authored by yuguo's avatar yuguo
Browse files

[DCU] Make the number of userbuffers (UB) streams configurable via the NVTE_UB_STREAM_NUMS environment variable

parent 1312aa6e
......@@ -496,7 +496,7 @@ def _train(opts):
if opts.benchmark:
# Warmup to not profile CPU overhead
for _ in range(20):
for _ in range(opts.benchmark_iter):
if opts.use_cuda_graphs:
test_graph.replay()
else:
......
......@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority = gemm_priority;
_comm_priority = comm_priority;
}
static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
_stream_compute.push_back(std::move(stream));
if (compute_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
}
_stream_compute.push_back(compute_streams[i]);
}
_num_splits = num_splits;
......@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
_ub_stream_nums = num_max_streams;
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
......@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
static cudaStream_t comm_stream;
if (comm_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
}
_stream_comm = comm_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
......@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::bulk_overlap
......@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
}
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
......@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
......@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
......@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
......@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
_ub_stream_nums = num_max_streams;
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
......@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
static cudaStream_t recv_stream;
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
if (send_streams[i] == nullptr) {
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
}
_stream_send.push_back(send_streams[i]);
}
if (recv_stream == nullptr) {
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&recv_stream, cudaStreamNonBlocking, _comm_priority));
}
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
_stream_recv = recv_stream;
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
......@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < num_steps - 1) {
// P2P communication
......@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
}
if (i < _tp_size - 1) {
// P2P communication
......@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
} // CommOverlapP2PBase::split_overlap_ag
/*
......@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
if (_ub_stream_nums == 1) {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
} else {
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
}
if (i > 0) {
// P2P communication chunk
......@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
......
......@@ -15,11 +15,7 @@
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif
namespace transformer_engine {
......@@ -141,6 +137,7 @@ class CommOverlapCore {
class CommOverlapBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
int _rs_kernel_type;
bool _rs_overlap_first_gemm;
cudaStream_t _stream_comm;
......@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
class CommOverlapP2PBase : public CommOverlapCore {
protected:
int _ub_stream_nums;
bool _is_reduce_scatter{false};
bool _use_multiatomic_ag{false};
bool _aggregate;
......
......@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}
multi_stream_cublas_batchgemm_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
......
Markdown is supported
Attach a file by dragging &amp; dropping it, or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.