Unverified Commit 180de056 authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges (#732)



* Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* FP8 reduction for atomic TP-RS with p2p exchange
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 580eb52b
...@@ -139,6 +139,9 @@ def fp8_gemm( ...@@ -139,6 +139,9 @@ def fp8_gemm(
extra_output_tensor is not None extra_output_tensor is not None
), 'ATOMIC_GEMM_RS_P2P requires extra output tensor' ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,)) args = tuple(args + (extra_output_tensor,))
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
else:
_ = fn(*args) _ = fn(*args)
return out, gelu_input return out, gelu_input
......
...@@ -623,26 +623,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -623,26 +623,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_ubuf_scale_inv_initialized = false; _ubuf_scale_inv_initialized = false;
_atomic_gemm = atomic_gemm; _atomic_gemm = atomic_gemm;
_self_chunk_id = _tp_id;
if (_atomic_gemm) { if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options); counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1); counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;
if (!is_reduce_scatter) { if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (rank == 0 && env_p != nullptr) { if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') { if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n"); printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
} }
} }
_self_chunk_id = 0;
counter.index_put_({_self_chunk_id}, 0); counter.index_put_({_self_chunk_id}, 0);
} }
} }
...@@ -675,13 +669,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -675,13 +669,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Get GEMM dimensions between TN and NN input layouts // Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1); const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0); const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0); const int n = _ubuf.size(0);
const int n_chunk = n / _tp_size;
// Get communication and GEMM output chunk sizes // Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create an GEMM output buffer with N+1 chunks in a contiguous memory
torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options());
D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options());
// Get output and workspace data pointers // Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr()); char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr()); int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size(); int workspace_size_chunk = workspaceSize / _stream_compute.size();
...@@ -692,100 +690,75 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -692,100 +690,75 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel()) if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor]; B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
// Catch up the default torch stream // Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options());
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
for (int i = 0; i < _tp_size; i++) {
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. Buffer under send is the input for the current // Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to // GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
// have the AG output in all ranks to be contiguous after the ring // have the AG output in all ranks to be contiguous after the ring
// exchanges // exchanges
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size; int send_chunk_id = i;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size; int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id; int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id; int recv_offset = comm_bytes * recv_chunk_id;
if (i < _tp_size - 1) { const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') { if (env_p != nullptr && env_p[0] == '1') {
userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id],
(cudaStream_t)_stream_recv);
} else if (env_p != nullptr && env_p[0] == '2') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, false, (cudaStream_t)_stream_recv);
}
} else if (env_p != nullptr && env_p[0] == '3') {
if (i == 0) { if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes, userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, _ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv); counter_ptr, true, (cudaStream_t)_stream_recv);
} }
} else { } else {
// P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
// userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, _ub_comm, _next_rank, (cudaStream_t) _stream_recv);
// comm_bytes, _ub_comm, userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
// _next_rank, (cudaStream_t)_stream_send); _ub_comm, _prev_rank, (cudaStream_t) _stream_recv);
// userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset,
// comm_bytes, _ub_comm,
// _prev_rank, (cudaStream_t)_stream_recv);
// CHECK_CUDA(cudaEventRecord(_stop_recv,
// (cudaStream_t)_stream_recv));
// CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send,
// _stop_recv, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm,
_next_rank, _prev_rank, (cudaStream_t)_stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv); producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
} }
if (i == 0) { if (i == 0) {
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb, te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter); _math_sms, 0, _tp_size, false, counter);
} }
} else { }
// GEMM
// userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes, // Store the input activation for backprop
// _ub_comm,
// _next_rank, _tp_size, comm_bytes, comm_bytes,
// (cudaStream_t)_stream_send);
// userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _prev_rank, _tp_size, counter_ptr,
// (cudaStream_t)_stream_recv);
if (B_copy.numel() > 0) { if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel()); assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(), CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), _ubufs[_self_chunk_id].numel() *
_ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send)); cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send)); CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
} }
}
}
for (int i = 0; i < _tp_size; i++) {
if (i != _self_chunk_id) {
consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
return D; // Reset atomic counters
consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main);
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.data_ptr());
CHECK_CUDA(cudaMemcpyAsync(
src_ptr + (D.numel() * D.element_size()),
src_ptr,
n_chunk * m * D.element_size(),
cudaMemcpyDeviceToDevice,
(cudaStream_t) stream_main));
// Return the last N rows of D_buffer
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return;
} // atomic_gemm_overlap_ag } // atomic_gemm_overlap_ag
/* /*
...@@ -1018,6 +991,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1018,6 +991,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
// Atomic GEMM // Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
...@@ -1031,8 +1005,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1031,8 +1005,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_chunk_id = send_chunk_id + _tp_size; int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id; int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id; int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
...@@ -1045,10 +1019,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1045,10 +1019,18 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Reduce GEMM output chunks // Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
_tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob( torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} }
}
/* /*
** Split ReduceScatter + GEMM using P2P communication ** Split ReduceScatter + GEMM using P2P communication
...@@ -1174,7 +1156,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1174,7 +1156,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS)
NVTE_ERROR("Invalid comm_type"); NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS) if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1); int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
......
...@@ -3671,6 +3671,20 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) { ...@@ -3671,6 +3671,20 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
} }
} }
// consumer_batch
// Waits on a contiguous range of per-chunk flags and re-arms each one.
//
//   atomic_ptr     - base of the flag array, accessed as unsigned int
//   first_chunk_i  - index of the first flag to wait on
//   num_chunks     - exclusive upper bound of the flag index range; the loop
//                    visits indices [first_chunk_i, num_chunks), so despite
//                    its name this is an end index, not a count
//
// Only thread 0 of block 0 does any work; every other thread exits at once.
static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) {
  if (blockIdx.x != 0 || threadIdx.x != 0) {
    return;
  }
  unsigned int *flags = (unsigned int *)atomic_ptr;
  for (int chunk = first_chunk_i; chunk < num_chunks; chunk++) {
    // Spin until the flag reads 0. atomicCAS(addr, 0, 0) never changes the
    // stored value; it is used here purely as an atomic read.
    while (atomicCAS(flags + chunk, 0, 0) != 0) {
    }
    // Set the flag back to 1 so it has to be cleared again before the next wait.
    flags[chunk] = 1;
    // PTX fence.sc.gpu, issued after the flag store.
    asm volatile("fence.sc.gpu;\n");
  }
}
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1); dim3 block(1);
dim3 grid(1); dim3 grid(1);
...@@ -3683,6 +3697,12 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { ...@@ -3683,6 +3697,12 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i); consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
} }
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}
template <typename fp8type> template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4) __global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
......
...@@ -151,6 +151,7 @@ typedef struct communicator communicator; ...@@ -151,6 +151,7 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream); void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
int create_communicator(communicator **comm); int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */ /* creates communicator, allocates all internal buffers if necessary */
......
...@@ -45,6 +45,7 @@ _cublas_workspace = None ...@@ -45,6 +45,7 @@ _cublas_workspace = None
_ub_communicators = None _ub_communicators = None
_NUM_MAX_UB_STREAMS = 3 _NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None _amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = []
def get_cublas_workspace_size_bytes() -> None: def get_cublas_workspace_size_bytes() -> None:
...@@ -138,6 +139,12 @@ def initialize_ub( ...@@ -138,6 +139,12 @@ def initialize_ub(
} }
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"}
rs_ag_pairs = {v : k for k, v in ag_rs_pairs.items()}
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
def get_method(name): def get_method(name):
for method, names in methods.items(): for method, names in methods.items():
if name in names: if name in names:
...@@ -160,20 +167,35 @@ def initialize_ub( ...@@ -160,20 +167,35 @@ def initialize_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases." "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
) )
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if is_reduce_scatter and method == "ring_exchange":
raise ValueError(
"Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method."
)
if method == 'bulk': if method == 'bulk':
warnings.warn( warnings.warn(
"Atoimic GEMM not is supported for a bulk overlap." f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`." "Defaulting to `atomic_gemm=False`."
) )
atomic_gemm = 0 atomic_gemm = 0
if not is_reduce_scatter and method == 'pipeline': if not is_reduce_scatter and method == 'pipeline':
raise ValueError( raise ValueError(
"`pipeline` overlap method is not supported for AllGather." f"At {name}, `pipeline` overlap method is not supported for AllGather."
) )
# Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`.
# Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
global layers_atomic_ring_exchange
if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
if name in rs_ag_pairs:
assert_message = (
f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
"outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and "
"GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
"`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config "
"for functionality."
)
if name in layers_atomic_ring_exchange:
assert atomic_gemm and method == "ring_exchange", assert_message
else:
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
sample_buffer = torch.empty( sample_buffer = torch.empty(
shape, shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
...@@ -213,7 +235,7 @@ def initialize_ub( ...@@ -213,7 +235,7 @@ def initialize_ub(
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name) method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16 num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2 cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment