Unverified Commit 180de056 authored by Sangkug Lym's avatar Sangkug Lym Committed by GitHub
Browse files

Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges (#732)



* Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* FP8 reduction for atomic TP-RS with p2p exchange
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>

* Fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarSangkug Lym <slym@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 580eb52b
......@@ -139,7 +139,10 @@ def fp8_gemm(
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,))
_ = fn(*args)
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
else:
_ = fn(*args)
return out, gelu_input
......
......@@ -623,26 +623,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_ubuf_scale_inv_initialized = false;
_atomic_gemm = atomic_gemm;
_self_chunk_id = _tp_id;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;
if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
counter.index_put_({_self_chunk_id}, 0);
}
}
......@@ -675,13 +669,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0);
const int n = _ubuf.size(0);
const int n_chunk = n / _tp_size;
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create an GEMM output buffer with N+1 chunks in a contiguous memory
torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options());
D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options());
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
......@@ -692,100 +690,75 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
assert(pre_gelu_out.numel() == 0);
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
for (int i = 0; i < _tp_size; i++) {
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. Buffer under send is the input for the current
// GEMM chunk The initial input chunk is stored _ubuf[rank]. This is to
// have the AG output in all ranks to be contiguous after the ring
// exchanges
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
if (i < _tp_size - 1) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id],
(cudaStream_t)_stream_recv);
} else if (env_p != nullptr && env_p[0] == '2') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, false, (cudaStream_t)_stream_recv);
}
} else if (env_p != nullptr && env_p[0] == '3') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv);
}
} else {
// P2P communication
// userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset,
// comm_bytes, _ub_comm,
// _next_rank, (cudaStream_t)_stream_send);
// userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset,
// comm_bytes, _ub_comm,
// _prev_rank, (cudaStream_t)_stream_recv);
// CHECK_CUDA(cudaEventRecord(_stop_recv,
// (cudaStream_t)_stream_recv));
// CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send,
// _stop_recv, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm,
_next_rank, _prev_rank, (cudaStream_t)_stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
}
const char *env_p = std::getenv("NVTE_AG_P2P_MULTI_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == 0) {
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv);
}
} else {
// GEMM
// userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _next_rank, _tp_size, comm_bytes, comm_bytes,
// (cudaStream_t)_stream_send);
// userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _prev_rank, _tp_size, counter_ptr,
// (cudaStream_t)_stream_recv);
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, _next_rank, (cudaStream_t) _stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, _prev_rank, (cudaStream_t) _stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
}
}
for (int i = 0; i < _tp_size; i++) {
if (i != _self_chunk_id) {
consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]);
if (i == 0) {
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
return D;
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
_ubufs[_self_chunk_id].numel() *
_ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}
// Reset atomic counters
consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main);
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.data_ptr());
CHECK_CUDA(cudaMemcpyAsync(
src_ptr + (D.numel() * D.element_size()),
src_ptr,
n_chunk * m * D.element_size(),
cudaMemcpyDeviceToDevice,
(cudaStream_t) stream_main));
// Return the last N rows of D_buffer
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return;
} // atomic_gemm_overlap_ag
/*
......@@ -1018,6 +991,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
......@@ -1031,23 +1005,31 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, send_rank, (cudaStream_t) _stream_recv);
_ub_comm, send_rank, (cudaStream_t) _stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
_tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
}
/*
......@@ -1174,7 +1156,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
......
......@@ -3671,6 +3671,20 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
}
}
// consumer_batch
static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) {
// Wait for producer to change the val to 0, which signal producer ready
if (blockIdx.x == 0 && threadIdx.x == 0) {
int old_val;
for (int i = first_chunk_i; i < num_chunks; i++) {
while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[i] = 1;
asm volatile("fence.sc.gpu;\n");
}
}
}
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
......@@ -3683,6 +3697,12 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
}
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}
template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
......
......@@ -151,6 +151,7 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
......
......@@ -45,6 +45,7 @@ _cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_amax_reduce_handle_bwd = None
layers_atomic_ring_exchange = []
def get_cublas_workspace_size_bytes() -> None:
......@@ -138,6 +139,12 @@ def initialize_ub(
}
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"}
rs_ag_pairs = {v : k for k, v in ag_rs_pairs.items()}
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
def get_method(name):
for method, names in methods.items():
if name in names:
......@@ -160,20 +167,35 @@ def initialize_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if is_reduce_scatter and method == "ring_exchange":
raise ValueError(
"Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method."
)
if method == 'bulk':
warnings.warn(
"Atoimic GEMM not is supported for a bulk overlap."
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`."
)
atomic_gemm = 0
if not is_reduce_scatter and method == 'pipeline':
raise ValueError(
"`pipeline` overlap method is not supported for AllGather."
f"At {name}, `pipeline` overlap method is not supported for AllGather."
)
# Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`.
# Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
global layers_atomic_ring_exchange
if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
if name in rs_ag_pairs:
assert_message = (
f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
"outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and "
"GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
"`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config "
"for functionality."
)
if name in layers_atomic_ring_exchange:
assert atomic_gemm and method == "ring_exchange", assert_message
else:
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
sample_buffer = torch.empty(
shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
......@@ -213,7 +235,7 @@ def initialize_ub(
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment