Unverified commit b855656b authored by Sangkug Lym, committed by GitHub

TP-RS overlap with send/recv ring-exchange (#724)



* TP-RS overlap with send/recv

Atomic GEMM based TP-RS overlap with send/recv
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Specify the userbuffer overlap method of each overlap instance
Signed-off-by: Sangkug Lym <slym@nvidia.com>

P2P TP-RS overlap with FP8 GEMM outputs
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Fix TP-RS overlap with send/recv
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* linting
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix typo
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 59bfc17b
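The user-facing change is that the four per-module flags `ub_split_rs`, `ub_split_ag`, `ub_atomic_gemm_rs`, and `ub_atomic_gemm_ag` collapse into `ub_overlap_rs` and `ub_overlap_ag`; whether a given overlap instance uses the pipelined, atomic-GEMM, or P2P ring-exchange variant is now decided per userbuffer at `initialize_ub` time. A rough usage sketch; the surrounding setup is illustrative, and the `initialize_ub` import path and any arguments beyond the fields visible in this diff are assumptions:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module.base import initialize_ub  # path assumed

# Register userbuffers once per process; "proj_fprop" is an RS-overlap layer,
# and the atomic_gemm knob (added in this commit) selects the atomic-GEMM path.
initialize_ub(
    shape=[4096, 8192],          # illustrative sample-buffer shape
    tp_size=8,
    use_fp8=True,
    ub_cfgs={"proj_fprop": {"atomic_gemm": 1}},
)

# Modules now take the consolidated flags instead of the four
# ub_split_* / ub_atomic_gemm_* booleans.
mha = te.MultiheadAttention(
    hidden_size=8192,
    num_attention_heads=64,
    ub_overlap_rs=True,
    ub_overlap_ag=True,
)
```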
@@ -3175,10 +3175,8 @@ class MultiheadAttention(torch.nn.Module):
         qkv_weight_interleaved: bool = True,
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
-        ub_split_rs: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_rs: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_ag: bool = False,
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
@@ -3265,9 +3263,8 @@ class MultiheadAttention(torch.nn.Module):
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
-                ub_split_ag=ub_split_ag,
+                ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
-                ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                 ub_name="qkv",
                 **common_gemm_kwargs,
             )
@@ -3297,9 +3294,8 @@ class MultiheadAttention(torch.nn.Module):
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
-                ub_split_ag=ub_split_ag,
+                ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
-                ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                 ub_name="qkv",
                 **common_gemm_kwargs,
             )
@@ -3347,10 +3343,8 @@ class MultiheadAttention(torch.nn.Module):
             bias=bias,
             return_bias=return_bias,
             parallel_mode="row" if set_parallel_mode else None,
-            ub_split_rs=ub_split_rs,
-            ub_split_ag=ub_split_ag,
-            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
+            ub_overlap_rs=ub_overlap_rs,
+            ub_overlap_ag=ub_overlap_ag,
             ub_name="proj",
             **common_gemm_kwargs,
         )
@@ -101,14 +101,14 @@ def fp8_gemm(
             empty_tensor if extra_output_tensor is None else extra_output_tensor
         )
         args = tuple(args + (0, extra_output_tensor,))
-    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-        fn = ub.split_overlap_ag
+    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+        fn = ub.split_overlap_ag_p2p
         extra_output_tensor = (
             empty_tensor if extra_output_tensor is None else extra_output_tensor
         )
         args = tuple(args + (extra_output_tensor,))
-    elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
-        fn = ub.atomic_gemm_overlap_ag
+    elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
+        fn = ub.atomic_gemm_overlap_ag_p2p
         extra_output_tensor = (
             empty_tensor if extra_output_tensor is None else extra_output_tensor
         )
@@ -119,12 +119,24 @@ def fp8_gemm(
         assert (
             extra_output_tensor is not None
         ), 'SPLIT_PIPELINED_RS requires extra output tensor'
        args = tuple(args + (True, extra_output_tensor,))
+    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+        fn = ub.split_overlap_rs_p2p
+        assert (
+            extra_output_tensor is not None
+        ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+        args = tuple(args + (extra_output_tensor,))
    elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
        fn = ub.atomic_gemm_overlap_rs
        assert (
            extra_output_tensor is not None
        ), 'ATOMIC_GEMM_RS requires extra output tensor'
        args = tuple(args + (True, extra_output_tensor,))
+    elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
+        fn = ub.atomic_gemm_overlap_rs_p2p
+        assert (
+            extra_output_tensor is not None
+        ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
+        args = tuple(args + (extra_output_tensor,))
    _ = fn(*args)
    return out, gelu_input
@@ -217,8 +229,8 @@ def gemm(
    elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
        fn = ub.bulk_overlap
        args = tuple(args + (0, empty_tensor))
-    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-        fn = ub.split_overlap_ag
+    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+        fn = ub.split_overlap_ag_p2p
        extra_output_tensor = (
            empty_tensor if extra_output_tensor is None else extra_output_tensor
        )
@@ -229,6 +241,12 @@ def gemm(
            extra_output_tensor is not None
        ), 'SPLIT_PIPELINED_RS requires extra output tensor'
        args = tuple(args + (False, extra_output_tensor,))
+    elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+        fn = ub.split_overlap_rs_p2p
+        assert (
+            extra_output_tensor is not None
+        ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+        args = tuple(args + (extra_output_tensor,))
    _ = fn(*args)
    return out, grad_bias, gelu_input
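The new reduce-scatter entries follow the same dispatch shape as the existing ones: each `UbufOverlapAlgo` value maps to a bound method on the communicator plus an algorithm-specific argument tail, and the P2P variants drop the leading reduce flag. A distilled, self-contained sketch of that pattern, with a toy enum and stand-in methods rather than the real signatures:

```python
from enum import Enum

class Algo(Enum):
    SPLIT_PIPELINED_RS = 3
    SPLIT_PIPELINED_RS_P2P = 4
    ATOMIC_GEMM_RS = 5
    ATOMIC_GEMM_RS_P2P = 7

class DummyUb:
    # Stand-ins for the communicator methods bound in this commit; they just
    # echo their argument tails so the dispatch shape is visible.
    def split_overlap_rs(self, *args): return ("split_rs", args[-2:])
    def split_overlap_rs_p2p(self, *args): return ("split_rs_p2p", args[-1:])
    def atomic_gemm_overlap_rs(self, *args): return ("atomic_rs", args[-2:])
    def atomic_gemm_overlap_rs_p2p(self, *args): return ("atomic_rs_p2p", args[-1:])

def rs_call(ub, algo, args, extra_output_tensor):
    assert extra_output_tensor is not None, f"{algo.name} requires extra output tensor"
    if algo is Algo.SPLIT_PIPELINED_RS:
        return ub.split_overlap_rs(*args, True, extra_output_tensor)
    if algo is Algo.SPLIT_PIPELINED_RS_P2P:
        return ub.split_overlap_rs_p2p(*args, extra_output_tensor)  # no reduce flag
    if algo is Algo.ATOMIC_GEMM_RS:
        return ub.atomic_gemm_overlap_rs(*args, True, extra_output_tensor)
    return ub.atomic_gemm_overlap_rs_p2p(*args, extra_output_tensor)

print(rs_call(DummyUb(), Algo.SPLIT_PIPELINED_RS_P2P, ("gemm", "args"), "rs_out"))
```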
@@ -41,10 +41,12 @@ enum class COMM_TYPE { RS = 0, AG = 1 };
 enum class UBOverlapAlgo {
   BULK_OVERLAP_AG = 0,
   BULK_OVERLAP_RS = 1,
-  SPLIT_PIPELINED_AG = 2,
+  SPLIT_PIPELINED_AG_P2P = 2,
   SPLIT_PIPELINED_RS = 3,
-  ATOMIC_GEMM_RS = 4,
-  ATOMIC_GEMM_AG = 5
+  SPLIT_PIPELINED_RS_P2P = 4,
+  ATOMIC_GEMM_RS = 5,
+  ATOMIC_GEMM_AG_P2P = 6,
+  ATOMIC_GEMM_RS_P2P = 7
 };

 struct UbufBase {
@@ -70,9 +72,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
   int comm_sms;
   int cga_size;
   int use_ce;
+  bool _atomic_gemm;

   UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size,
-                  int num_splits, bool set_sm_margin, int num_max_streams,
+                  int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm,
                   torch::Tensor empty_tensor) {
     // Initialize userbuf communicator
     if (!comm_created) {
@@ -116,9 +119,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
     _math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

     output_tensor = torch::Tensor();
-    auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
-    counter = torch::zeros({num_splits * 2}, counter_options);
-    counter.index_put_({Slice(None, num_splits)}, 1);
+    _atomic_gemm = atomic_gemm;
+    if (_atomic_gemm) {
+      auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
+      counter = torch::zeros({num_splits * 2}, counter_options);
+      counter.index_put_({Slice(None, num_splits)}, 1);
+    }
     // CUDA event creation
     cudaEventCreateWithFlags(&_start_compute, 0);
     cudaEventCreateWithFlags(&_stop_compute, 0);
@@ -519,12 +525,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
     output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
     return output_tensor;
   }
+
+  bool is_atomic_gemm() { return _atomic_gemm; }
+  bool is_p2p_overlap() { return false; }
 };  // UbufCommOverlap

 struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
   int _tp_id;
   int _tp_size;
-  int _ub_reg;
+  int _ub_reg, _ub_reg2;
   int _next_rank, _prev_rank, _rank, _rank_round_tp;
   int _aggregate2;
   int _math_sms;
@@ -533,18 +542,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
   torch::Tensor _ubuf;
   torch::Tensor counter;
   torch::Tensor _empty_tensor;
+  torch::Tensor _ubuf_scale_inv;
+  bool _ubuf_scale_inv_initialized;
   std::vector<torch::Tensor> _ubufs;
   at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true);
   at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
   std::vector<at::cuda::CUDAStream> _stream_compute;
-  cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv;
+  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv;
   int use_ce;
   int sms;
   int cga_size;
+  bool _atomic_gemm;

   UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm,
                      int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams,
-                     torch::Tensor empty_tensor) {
+                     bool is_reduce_scatter, bool atomic_gemm, torch::Tensor empty_tensor) {
     // Initialize userbuf communicator
     if (!comm_created) {
       if (rank == 0) {
@@ -561,16 +573,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     // Create workspace tensor with userbuffer
     int ubuf_bytes = sample.numel() * sample.element_size();
     int ubuf_chunk_bytes = ubuf_bytes / tp_size;
+    int num_ubuf_chunks = tp_size;
+    if (is_reduce_scatter) {
+      // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
+      // outputs for reduction at the end of the pipelining.
+      ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
+      num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
+    }
+
     _ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
                                               _ub_comm, true);
     if (rank == 0) {
       printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
     }
-    _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
+    _ubuf = torch::from_blob(
+        _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());

     // Create tensor chunks for easy management
     char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
-    for (int i = 0; i < tp_size; i++) {
+    for (int i = 0; i < num_ubuf_chunks; i++) {
       torch::Tensor ubuf_chunk = torch::from_blob(
           ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options());
       _ubufs.push_back(ubuf_chunk);
@@ -599,30 +620,37 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     _rank_round_tp = (rank / tp_size) * tp_size;
     _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp;
     _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp;
+    _ubuf_scale_inv_initialized = false;

-    auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
-    counter = torch::zeros({tp_size * 2}, counter_options);
-    counter.index_put_({Slice(None, tp_size)}, 1);
-    _self_chunk_id = _tp_id;
-
-    const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
-    if (rank == 0 && env_p != nullptr) {
-      if (env_p[0] == '1') {
-        printf("!!userbuffers_sendrecv_atomic\n");
-      } else if (env_p[0] == '2') {
-        printf("!!userbuffers_sendrecv_multiatomic\n");
-      } else if (env_p[0] == '3') {
-        printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
-        _self_chunk_id = 0;
-      } else {
-        printf("!!userbuffers_sendrecv\n");
+    _atomic_gemm = atomic_gemm;
+    if (_atomic_gemm) {
+      auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
+      counter = torch::zeros({tp_size * 2}, counter_options);
+      counter.index_put_({Slice(None, tp_size)}, 1);
+      _self_chunk_id = _tp_id;
+
+      if (!is_reduce_scatter) {
+        const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
+        if (rank == 0 && env_p != nullptr) {
+          if (env_p[0] == '1') {
+            printf("!!userbuffers_sendrecv_atomic\n");
+          } else if (env_p[0] == '2') {
+            printf("!!userbuffers_sendrecv_multiatomic\n");
+          } else if (env_p[0] == '3') {
+            printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
+            _self_chunk_id = 0;
+          } else {
+            printf("!!userbuffers_sendrecv\n");
+          }
+        }
+        counter.index_put_({_self_chunk_id}, 0);
       }
     }
-    counter.index_put_({_self_chunk_id}, 0);

     // CUDA event creation
     cudaEventCreateWithFlags(&_start_compute, 0);
     cudaEventCreateWithFlags(&_stop_compute, 0);
+    cudaEventCreateWithFlags(&_start_comm, 0);
     cudaEventCreateWithFlags(&_stop_send, 0);
     cudaEventCreateWithFlags(&_stop_recv, 0);
   }
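The sizing arithmetic above is the core of the allocation change: for GEMM+RS the P2P communicator reserves `2 * tp_size - 1` chunk-sized slots, the first `tp_size` for locally produced GEMM chunks and the remaining `tp_size - 1` for partials received over the ring. A small self-contained sketch of the layout math:

```python
# Sketch of the buffer-sizing arithmetic in the constructor above.
def p2p_ubuf_layout(sample_numel, elem_size, tp_size, is_reduce_scatter):
    ubuf_bytes = sample_numel * elem_size
    ubuf_chunk_bytes = ubuf_bytes // tp_size
    num_ubuf_chunks = tp_size
    if is_reduce_scatter:
        # 2 * tp_size - 1 chunks: produced GEMM chunks plus received partials
        # kept around for the reduction at the end of the pipeline.
        num_ubuf_chunks = 2 * tp_size - 1
        ubuf_bytes = ubuf_chunk_bytes * num_ubuf_chunks
    return ubuf_bytes, ubuf_chunk_bytes, num_ubuf_chunks

assert p2p_ubuf_layout(8 * 1024, 2, 8, True)[2] == 15
```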
@@ -758,7 +786,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));

     return D;
-  }  // split_overlap_ag
+  }  // atomic_gemm_overlap_ag
+
   /*
   ** Split AllGather + GEMM using P2P communication
   ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
@@ -948,6 +977,174 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     return D;
   }  // split_overlap_ag

+  /*
+  ** Split ReduceScatter + GEMM using P2P communication
+  */
+  void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+                              transformer_engine::DType A_type, bool transa, at::Tensor B,
+                              at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
+                              transformer_engine::DType B_type, bool transb, at::Tensor D,
+                              at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
+                              at::Tensor bias, transformer_engine::DType bias_type,
+                              at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+                              size_t workspaceSize, bool accumulate, bool use_split_accumulator,
+                              at::Tensor rs_output) {
+    _ub_comm->use_ce = use_ce;
+    _ub_comm->sms = sms;
+    _ub_comm->cga_size = cga_size;
+    int k = A.size(1);
+    int n = B.size(0);
+
+    // Get communication and GEMM input chunk sizes
+    int n_chunk = n / _tp_size;
+    const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
+    const int input_b_chunk_bytes = n_chunk * k * B.element_size();
+
+    // Get input and workspace data pointers
+    char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
+    char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
+    int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
+    int workspace_size_chunk = workspaceSize / _stream_compute.size();
+
+    if (A_scale_inverse.numel())
+      A_scale_inverse = A_scale_inverse[A_fp8_tensor];
+
+    if (B_scale_inverse.numel())
+      B_scale_inverse = B_scale_inverse[B_fp8_tensor];
+
+    // Catch up the main stream
+    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
+
+    // Atomic GEMM
+    torch::Tensor workspace_chunk =
+        torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
+    te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
+                   _ubuf, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
+                   workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
+                   _math_sms, 0, _tp_size, true, counter);
+
+    // P2P communication chunk
+    for (int i = 1; i < _tp_size; i++) {
+      int send_chunk_id = i - 1;
+      int recv_chunk_id = send_chunk_id + _tp_size;
+      int send_offset = comm_bytes * send_chunk_id;
+      int recv_offset = comm_bytes * recv_chunk_id;
+      int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
+      int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
+
+      consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
+      userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
+                       _ub_comm, send_rank, (cudaStream_t)_stream_recv);
+      userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
+                       _ub_comm, recv_rank, (cudaStream_t)_stream_recv);
+    }
+    CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+
+    // Reduce GEMM output chunks
+    char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
+    torch::Tensor reduce_buf = torch::from_blob(
+        reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
+    torch::sum_out(rs_output, reduce_buf, 0);
+  }
+
+  /*
+  ** Split ReduceScatter + GEMM using P2P communication
+  */
+  void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
+                        transformer_engine::DType A_type, bool transa, at::Tensor B,
+                        at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
+                        transformer_engine::DType B_type, bool transb, at::Tensor D,
+                        at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
+                        at::Tensor bias, transformer_engine::DType bias_type,
+                        at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
+                        size_t workspaceSize, bool accumulate, bool use_split_accumulator,
+                        at::Tensor rs_output) {
+    _ub_comm->use_ce = use_ce;
+    _ub_comm->sms = sms;
+    _ub_comm->cga_size = cga_size;
+    int k = A.size(1);
+    int n = B.size(0);
+
+    // Get communication and GEMM input chunk sizes
+    int n_chunk = n / _tp_size;
+    const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
+    const int input_b_chunk_bytes = n_chunk * k * B.element_size();
+
+    // Get input and workspace data pointers
+    char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
+    char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
+    int workspace_size_chunk = workspaceSize / _stream_compute.size();
+
+    if (A_scale_inverse.numel())
+      A_scale_inverse = A_scale_inverse[A_fp8_tensor];
+
+    if (B_scale_inverse.numel())
+      B_scale_inverse = B_scale_inverse[B_fp8_tensor];
+
+    // Catch up the main stream
+    at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
+    CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
+    for (int i = 0; i < _stream_compute.size(); i++) {
+      CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
+    }
+
+    // GEMM and send/recv chunks
+    for (int i = 0; i < _tp_size; i++) {
+      // GEMM chunk
+      int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
+      char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
+      torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options());
+      // Store the last GEMM chunk output to the receive buffer.
+      torch::Tensor workspace_chunk = torch::from_blob(
+          workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
+          {workspace_size_chunk}, workspace.options());
+      if (i == _tp_size - 1) {
+        at::cuda::setCurrentCUDAStream(stream_main);
+      } else {
+        at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
+      }
+      te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
+              _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
+              workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
+              _math_sms);
+
+      if (i > 0) {
+        // P2P communication chunk
+        int send_offset = comm_bytes * (i - 1);
+        int recv_offset = comm_bytes * (i - 1 + _tp_size);
+        int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
+        int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
+        CHECK_CUDA(cudaEventRecord(
+            _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
+        CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_comm, 0));
+        CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_comm, 0));
+        userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
+                         _ub_comm, send_rank, (cudaStream_t)_stream_send);
+        userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
+                         _ub_comm, recv_rank, (cudaStream_t)_stream_recv);
+      }
+    }
+    CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t)_stream_recv));
+    CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
+
+    // Reduce GEMM output chunks
+    char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
+    if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
+      assert(_ubuf_scale_inv_initialized);
+      float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
+      char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
+      reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
+                                            _tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);
+    } else {
+      torch::Tensor reduce_buf = torch::from_blob(
+          reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
+      torch::sum_out(rs_output, reduce_buf, 0);
+    }
+  }
+
   /*
   ** Copy input to _ubufs[0]
   */
@@ -970,6 +1167,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
                                  (cudaStream_t)stream_main));
     }
   }
+
   torch::Tensor get_ubuf_output(int comm_type) {
     char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
     COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
@@ -981,6 +1179,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
     int output_c_dim1 = _ubuf.size(1);
     return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
   }
+
+  void set_ubuf_scale_inv(const torch::Tensor &scale_inv) {
+    _ubuf_scale_inv = scale_inv;
+    _ubuf_scale_inv_initialized = true;
+  }
+
+  bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); }
+  bool is_atomic_gemm() { return _atomic_gemm; }
+  bool is_p2p_overlap() { return true; }
 };  // UbufP2PCommOverlap

 }  // namespace ubuf
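The chunk/rank indexing shared by both new RS overlap loops can be checked in isolation; a standalone simulation of the send/recv ring schedule (chunk ids and peer ranks only, no real communication):

```python
# Simulation of the ring schedule in the RS overlap loops above: at step i,
# each rank forwards chunk i-1 and receives the matching partial into slot
# i-1+tp_size, the receive half of the 2*tp_size-1 chunk buffer.
def ring_rs_schedule(tp_size, tp_id, rank_round_tp=0):
    steps = []
    for i in range(1, tp_size):
        send_chunk_id = i - 1
        recv_chunk_id = send_chunk_id + tp_size  # received partials land here
        send_rank = (tp_id + i) % tp_size + rank_round_tp
        recv_rank = (tp_size + tp_id - i) % tp_size + rank_round_tp
        steps.append((send_chunk_id, send_rank, recv_chunk_id, recv_rank))
    return steps

# e.g. tp_size=4, tp_id=0: chunk 0 -> rank 1, partial from rank 3 into slot 4, ...
for step in ring_rs_schedule(4, 0):
    print(step)
```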
@@ -109,26 +109,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
       .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
       .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
-      .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
+      .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P)
+      .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P)
       .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
-      .value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
+      .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
+      .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);

   py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
-      .def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
+      .def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, bool, torch::Tensor>())
       .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
       .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
       .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
       .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
       .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
       .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
-      .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
+      .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output)
+      .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm)
+      .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);

   py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
-      .def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
-      .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
-      .def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
+      .def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, bool, bool, torch::Tensor>())
+      .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
+      .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
+      .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
+      .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs)
       .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
-      .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
+      .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output)
+      .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf)
+      .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm)
+      .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap)
+      .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv);

 #else  // NVTE_WITH_USERBUFFERS
   m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
   m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
@@ -3666,3 +3666,34 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
   dim3 grid(1);
   consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
 }
+
+template <typename fp8type>
+__global__ void __launch_bounds__(MAX_THREADS / 4)
+    reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
+                                const int num_inputs, const int input_size) {
+  const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
+  fp8type *inputs_fp8 = reinterpret_cast<fp8type *>(inputs);
+  float accum_buf = static_cast<float>(inputs_fp8[tid]) * (*scale);
+#pragma unroll
+  for (int i = 1; i < num_inputs; i++) {
+    accum_buf += static_cast<float>(inputs_fp8[tid + input_size * i]) * (*scale);
+  }
+  half *output_half = reinterpret_cast<half *>(output);
+  output_half[tid] = (half) accum_buf;
+}
+
+template <typename fp8type>
+void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
+                            int input_size, cudaStream_t stream) {
+  size_t num_threads = MAX_THREADS / 4;
+  size_t num_blocks = (input_size + num_threads - 1) / num_threads;
+  dim3 block(num_threads);
+  dim3 grid(num_blocks);
+  reduce_fp8_in_bf16_out_cuda<fp8type><<<grid, block, 0, stream>>>(
+      inputs, output, scale, num_inputs, input_size);
+}
+
+template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(
+    void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream);
+template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(
+    void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream);
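Numerically, the kernel dequantizes `num_inputs` FP8 chunks with a single scale-inverse, accumulates in FP32, and stores a 16-bit result (through `half` in the code above). A NumPy emulation of that reduction, with FP8 values modeled as pre-decoded floats since NumPy has no FP8 dtype:

```python
import numpy as np

def reduce_fp8_chunks(decoded_chunks, scale_inv):
    # decoded_chunks: (num_inputs, input_size) float array standing in for the
    # fp8 payloads after decoding; only the scaling + accumulation semantics
    # of reduce_fp8_in_bf16_out are emulated here.
    accum = (decoded_chunks.astype(np.float32) * scale_inv).sum(axis=0)
    return accum.astype(np.float16)  # the kernel stores through `half`

chunks = np.random.uniform(-1.0, 1.0, size=(4, 1024))
out = reduce_fp8_chunks(chunks, scale_inv=0.5)
assert out.shape == (1024,) and out.dtype == np.float16
```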
@@ -305,4 +305,8 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0);

 void destroy_communicator(communicator *comm);

+template <typename fp8type>
+void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs,
+                            int input_size, cudaStream_t stream);
+
 #endif  // TRANSFORMER_ENGINE_USERBUFFERS_H_
@@ -129,13 +129,14 @@ def initialize_ub(
         "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
     ]
     if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
-        fp8_buf.append ("proj_fprop")
+        fp8_buf += ["proj_fprop", "fc2_fprop"]
     # Default overlap methods for layers
     methods = {
         "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
         "pipeline":["proj_fprop", "fc2_fprop"],
         "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
     }
+    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]

     def get_method(name):
         for method, names in methods.items():
@@ -151,7 +152,28 @@ def initialize_ub(
         set_sm_margin: int = 0,
         num_splits: int = 4,
         aggregate: int = 0,
+        atomic_gemm: int = 0,
+        is_reduce_scatter: int = 0,
     ) -> None:
+        if atomic_gemm:
+            warnings.warn(
+                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
+            )
+            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
+            if is_reduce_scatter and method == "ring_exchange":
+                raise ValueError(
+                    "Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method."
+                )
+            if method == 'bulk':
+                warnings.warn(
+                    "Atomic GEMM is not supported for a bulk overlap. "
+                    "Defaulting to `atomic_gemm=False`."
+                )
+                atomic_gemm = 0
+        if not is_reduce_scatter and method == 'pipeline':
+            raise ValueError(
+                "`pipeline` overlap method is not supported for AllGather."
+            )
         sample_buffer = torch.empty(
             shape,
             dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
@@ -166,6 +188,8 @@ def initialize_ub(
                 set_sm_margin,        # Set SM margin
                 aggregate,            # Aggregate 2X GEMM chunks
                 _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
+                is_reduce_scatter,    # overlap with reduce scatter
+                atomic_gemm,          # use a single GEMM with atomic-counters
                 torch.Tensor(),       # empty tensor to pass to counters
             )
         else:
@@ -178,6 +202,7 @@ def initialize_ub(
                 num_splits,           # Number of communication splits
                 set_sm_margin,        # Set SM margin
                 _NUM_MAX_UB_STREAMS,  # Max concurrent GEMM streams
+                atomic_gemm,          # use a single GEMM with atomic-counters
                 torch.Tensor(),       # empty tensor to pass to counters
             )
         _ub_communicators[name] = ub_obj
@@ -191,6 +216,8 @@ def initialize_ub(
             num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
             set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
             aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
+            atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
+            is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
             add_ub(
                 name,
                 method,
@@ -198,7 +225,9 @@ def initialize_ub(
                 cga_size,
                 set_sm_margin,
                 num_splits,
-                aggregate
+                aggregate,
+                atomic_gemm,
+                is_reduce_scatter,
             )
         else:
             method = get_method(name)
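Configuration-wise, `atomic_gemm` is now a per-buffer knob read from `ub_cfgs`, while `is_reduce_scatter` is derived from the layer name via `layers_reduce_scatter_overlap`. A hedged sketch of a user-supplied config; only keys visible in this diff are used, and anything else about the `ub_cfgs` schema is assumed:

```python
# Only the ub_cfg keys read in the diff above are shown here.
ub_cfgs = {
    # "proj_fprop" is in layers_reduce_scatter_overlap, so add_ub() receives
    # is_reduce_scatter=1 for it; atomic_gemm=1 opts into the single GEMM
    # with atomic counters (FP8 only; its default method is "pipeline", since
    # atomic GEMM + "ring_exchange" reduce-scatter is rejected above).
    "proj_fprop": {"atomic_gemm": 1, "num_splits": 4, "set_sm_margin": 1},
    # An AG-side buffer keeps the default ring-exchange method; aggregate
    # merges 2X GEMM chunks per step.
    "qkv_fprop": {"aggregate": 1},
}
```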
@@ -632,12 +661,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
         gather_grad_output = row_parallel_mode and ctx.sequence_parallel

-        if gather_grad_output:
-            ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
-
         # No-FP8 case: bgrad is fused with wgrad for this case.
         if not ctx.fp8:
             if gather_grad_output:
-                if not ub_overlap_ag:
+                if not ctx.ub_overlap_ag:
                     grad_output_mat, _ = gather_along_first_dim(
                         grad_output_mat, ctx.tp_group
                     )
@@ -656,7 +683,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
             ):
                 assert (
-                    not ub_overlap_ag
+                    not ctx.ub_overlap_ag
                 ), "override_linear_precision.wgrad not supported with UB AG overlap"
                 grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
             # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
@@ -665,7 +692,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 grad_bias = grad_output_mat.sum(dim=0)
             else:
                 grad_bias = None
-            if ub_overlap_ag:
+            if ctx.ub_overlap_ag:
                 grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
             else:
                 grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
@@ -676,7 +703,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                 fp8_dtype_backward,
                 out=grad_output_c,
             )
-            if not ub_overlap_ag:
+            if not ctx.ub_overlap_ag:
                 grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                 grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
             else:
@@ -86,8 +86,7 @@ class _LayerNormLinear(torch.autograd.Function):
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
-        ub_split_ag: bool,
-        ub_atomic_gemm_ag: bool,
+        ub_overlap_ag: bool,
         ub_name: str,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
@@ -106,12 +105,11 @@ class _LayerNormLinear(torch.autograd.Function):
         if ln_bias is not None:
             ln_bias = cast_if_needed(ln_bias, activation_dtype)

-        if ub_split_ag or ub_atomic_gemm_ag:
+        if ub_overlap_ag:
             tp_world_size = get_distributed_world_size(tp_group)
             if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
-                ub_split_ag = False
-                ub_atomic_gemm_ag = False
-        if ub_split_ag or ub_atomic_gemm_ag:
+                ub_overlap_ag = False
+        if ub_overlap_ag:
             dim_size = list(inputmat.size())
             dim_size[0] = dim_size[0] * tp_world_size
             ub_obj_lnout = get_ub(ub_name+"_fprop")
@@ -119,8 +117,6 @@ class _LayerNormLinear(torch.autograd.Function):
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
             ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
-        if ub_atomic_gemm_ag:
-            assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."

         fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -138,9 +134,13 @@ class _LayerNormLinear(torch.autograd.Function):
         # Column Parallel Linear
         ln_out_gathered = False
-        if ub_split_ag or ub_atomic_gemm_ag:
+        if ub_overlap_ag:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
             ln_out = torch.empty_like(ln_out)
+            if ub_obj_lnout.is_atomic_gemm():
+                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+            else:
+                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         elif parallel_mode == "column" and sequence_parallel:
             ln_out_gathered = True
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
@@ -201,8 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
                 )
                 weight_t_fp8 = None

-            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
-            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
             out, _ = tex.fp8_gemm(
                 weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
@@ -217,9 +215,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 bias=bias,
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
-                ub_algo=ub_algo,
-                ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None,
-                extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None,
+                ub_algo=ub_algo if ub_overlap_ag else None,
+                ub=ub_obj_lnout if ub_overlap_ag else None,
+                extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
         else:
             # Cast for native AMP
@@ -243,9 +241,9 @@ class _LayerNormLinear(torch.autograd.Function):
                 get_workspace(),
                 bias=bias,
                 use_bias=use_bias,
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
-                ub=ub_obj_lnout if ub_split_ag else None,
-                extra_output_tensor=ln_out if ub_split_ag else None,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
+                ub=ub_obj_lnout if ub_overlap_ag else None,
+                extra_output_tensor=ln_out if ub_overlap_ag else None,
             )

         if is_grad_enabled:
@@ -624,7 +622,6 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
-            None,
         )
@@ -737,8 +734,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -758,23 +754,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
-        self.ub_split_ag = ub_split_ag
-        self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
-        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]):
+        self.ub_overlap_ag = ub_overlap_ag
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
             assert ub_name is not None, "Userbuffer name [string] is not set."
         self.ub_name = ub_name
-
-        if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
             assert (
                 tex.userbuf_comm_available()
             ), "Userbuffer communication backend not available."
-        if ub_atomic_gemm_ag:
-            warnings.warn(
-                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
-            )

         if tp_group is None:
             self.tp_size = tp_size
             if tp_size == 1:
@@ -1098,8 +1087,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
-            self.ub_split_ag,
-            self.ub_atomic_gemm_ag,
+            self.ub_overlap_ag,
             self.ub_name,
         )
         out = fwd_fn(*args)
@@ -117,10 +117,8 @@ class _LayerNormMLP(torch.autograd.Function):
         primary_weights_in_fp8: bool,
         ub_bulk_wgrad: bool,
         ub_bulk_dgrad: bool,
-        ub_split_rs: bool,
-        ub_atomic_gemm_rs: bool,
-        ub_split_ag: bool,
-        ub_atomic_gemm_ag: bool,
+        ub_overlap_rs: bool,
+        ub_overlap_ag: bool,
         gemm_gelu_fusion: bool,
     ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
         # Make sure input dimensions are compatible
@@ -142,25 +140,17 @@ class _LayerNormMLP(torch.autograd.Function):
         if ln_bias is not None:
             ln_bias = cast_if_needed(ln_bias, activation_dtype)

-        if ub_split_ag or ub_atomic_gemm_ag:
-            tp_world_size = get_distributed_world_size(tp_group)
+        tp_world_size = get_distributed_world_size(tp_group)
+        if ub_overlap_ag:
             if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
-                ub_split_ag = False
-                ub_atomic_gemm_ag = False
-        ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag
+                ub_overlap_ag = False
         if ub_overlap_ag:
             ub_obj_lnout = get_ub("fc1_fprop")
             ln_out = ub_obj_lnout.get_ubuf_output(0)
         else:
             ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
             ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
-        if ub_split_rs or ub_atomic_gemm_rs:
-            tp_world_size = get_distributed_world_size(tp_group)
-            if tp_world_size == 1:
-                ub_split_rs = False
-                ub_atomic_gemm_rs = False
-        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
-            assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
+        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs

         fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
@@ -181,6 +171,10 @@ class _LayerNormMLP(torch.autograd.Function):
         if ub_overlap_ag:
             ln_out_total = ub_obj_lnout.get_ubuf_output(1)
             ln_out = torch.empty_like(ln_out)
+            if ub_obj_lnout.is_atomic_gemm():
+                ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+            else:
+                ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         elif set_parallel_mode and sequence_parallel:
             ln_out_gathered = True
             ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
@@ -267,9 +261,6 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             fc2_weight_t_fp8 = None

-            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
-            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
-
             # Perform FP8 GEMM
             fp8_gemm_args = [
                 fc1_weight_fp8._data,
@@ -287,7 +278,7 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc1_bias,
                 use_bias=use_fc1_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
-                ub_algo=ub_algo,
+                ub_algo=ub_algo_ag if ub_overlap_ag else None,
                 ub=ub_obj_lnout if ub_overlap_ag else None,
                 extra_output_tensor=ln_out if ub_overlap_ag else None,
             )
@@ -321,13 +312,23 @@ class _LayerNormMLP(torch.autograd.Function):
             fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
                 None, None, None, activation_dtype)
-            if ub_split_rs or ub_atomic_gemm_rs:
+            if ub_overlap_rs:
                 ub_obj_fc2out = get_ub("fc2_fprop")
                 fc2_out = ub_obj_fc2out.get_ubuf_output(1)
                 dim_size = list(gelu_out.size())
                 dim_size[0] = dim_size[0] // tp_world_size
                 dim_size[1] = fc2_weight.size(0)
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
+                if ub_obj_fc2out.is_p2p_overlap():
+                    if ub_obj_fc2out.is_atomic_gemm():
+                        ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    else:
+                        ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    if ub_obj_fc2out.is_atomic_gemm():
+                        ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    else:
+                        ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS

                 if ub_obj_fc2out.is_fp8_ubuf():
                     fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
@@ -340,8 +341,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 dim_size[1] = fc2_weight.size(0)
                 fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)

-            ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
-            ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
             _ = tex.fp8_gemm(
                 fc2_weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
@@ -357,9 +356,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 use_bias=use_fc2_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
                 out=fc2_out,
-                ub_algo=ub_algo,
-                ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None,
-                extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None,
+                ub_algo=ub_algo_rs if ub_overlap_rs else None,
+                ub=ub_obj_fc2out if ub_overlap_rs else None,
+                extra_output_tensor=rs_out if ub_overlap_rs else None,
                 out_index=fc2_out_index,
                 fp8_meta_tensor = fc2_meta_tensor,
                 D_dtype = fc2_te_type,
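The four-way selection above is the runtime counterpart of the new enum entries. Factored into a helper, the mapping reads as follows; this is a direct restatement of the logic in this diff, with `ub_obj` exposing the predicates bound in the pybind section:

```python
from types import SimpleNamespace

def choose_rs_algo(tex, ub_obj):
    # ub_obj is a UbufCommOverlap or UbufP2PCommOverlap instance; both
    # predicates were added to the two classes in this commit.
    if ub_obj.is_p2p_overlap():
        if ub_obj.is_atomic_gemm():
            return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
        return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
    if ub_obj.is_atomic_gemm():
        return tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
    return tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS

# Toy check with stand-ins for tex and the communicator:
tex = SimpleNamespace(UbufOverlapAlgo=SimpleNamespace(
    ATOMIC_GEMM_RS_P2P="ATOMIC_GEMM_RS_P2P", SPLIT_PIPELINED_RS_P2P="SPLIT_PIPELINED_RS_P2P",
    ATOMIC_GEMM_RS="ATOMIC_GEMM_RS", SPLIT_PIPELINED_RS="SPLIT_PIPELINED_RS"))
ub = SimpleNamespace(is_p2p_overlap=lambda: True, is_atomic_gemm=lambda: False)
assert choose_rs_algo(tex, ub) == "SPLIT_PIPELINED_RS_P2P"
```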
...@@ -395,9 +394,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -395,9 +394,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias, bias=fc1_bias,
use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
gelu=not bias_gelu_nvfusion and (activation == 'gelu'), gelu=not bias_gelu_nvfusion and (activation == 'gelu'),
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_split_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None,
) )
if not is_grad_enabled: if not is_grad_enabled:
clear_tensor_data(ln_out_total) clear_tensor_data(ln_out_total)
...@@ -427,13 +426,17 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -427,13 +426,17 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.max(-amin, amax).float() torch.max(-amin, amax).float()
if ub_split_rs: if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop") ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1) fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size()) dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0) dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
if ub_obj_fc2out.is_p2p_overlap():
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else: else:
dim_size = list(gelu_out.size()) dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0) dim_size[1] = fc2_weight.size(0)
@@ -446,9 +449,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 bias=fc2_bias,
                 use_bias=use_fc2_bias,
                 out=fc2_out,
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
-                ub=ub_obj_fc2out if ub_split_rs else None,
-                extra_output_tensor=rs_out if ub_split_rs else None,
+                ub_algo=ub_algo_rs if ub_overlap_rs else None,
+                ub=ub_obj_fc2out if ub_overlap_rs else None,
+                extra_output_tensor=rs_out if ub_overlap_rs else None,
             )
             if not is_grad_enabled:
                 clear_tensor_data(gelu_out)
@@ -515,13 +518,12 @@ class _LayerNormMLP(torch.autograd.Function):
             ctx.zero_centered_gamma = zero_centered_gamma
             ctx.ub_bulk_wgrad = ub_bulk_wgrad
             ctx.ub_bulk_dgrad = ub_bulk_dgrad
-            ctx.ub_split_ag = ub_split_ag
-            ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+            ctx.ub_overlap_ag = ub_overlap_ag
             ctx.requires_dgrad = inp.requires_grad
             ctx.normalization = normalization
         # Row Parallel Linear
-        if ub_split_rs or ub_atomic_gemm_rs:
+        if ub_overlap_rs:
             fc2_out = rs_out
         elif set_parallel_mode and sequence_parallel:
             fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
@@ -590,18 +592,19 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size[0] = dim_size[0] * tp_world_size
             ub_obj_lnout = get_ub("fc1_dgrad")
             ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
-        ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
-        if ub_overlap_ag:
+        if ctx.ub_overlap_ag:
             tp_world_size = get_distributed_world_size(ctx.tp_group)
             if tp_world_size == 1:
-                ctx.ub_split_ag = False
                 ctx.ub_overlap_ag = False
-        ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
-        if ub_overlap_ag:
+        if ctx.ub_overlap_ag:
             dim_size = list(grad_outputs[0].size())
             dim_size[0] = dim_size[0] * tp_world_size
             ctx.ub_obj_gradout = get_ub("fc2_dgrad")
+            if ctx.ub_obj_gradout.is_atomic_gemm():
+                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+            else:
+                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess
         (
@@ -645,8 +648,6 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
-            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
-            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
             # FC2 DGRAD; Unconditional
             fc2_dgrad, _ = tex.fp8_gemm(
                 fc2_weight_t_fp8._data,
@@ -660,10 +661,10 @@ class _LayerNormMLP(torch.autograd.Function):
                 ctx.activation_dtype,
                 get_workspace(),
                 use_split_accumulator=_2X_ACC_DGRAD,
-                ub_algo=ub_algo,
-                ub=ctx.ub_obj_gradout if ub_overlap_ag else None,
+                ub_algo=ub_algo if ctx.ub_overlap_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
             )
-            if ub_overlap_ag:
+            if ctx.ub_overlap_ag:
                 grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                 clear_tensor_data(grad_output_c)
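The backward pass makes the same move for the GEMM flavor: `is_atomic_gemm()` on the `fc2_dgrad` buffer selects between the atomic-GEMM and split-GEMM variants of the ring all-gather, replacing the deleted `ub_split_ag`/`ub_atomic_gemm_ag` ladder. Condensed (hedged) form:

```python
# Assuming `ub` is the communicator returned by get_ub("fc2_dgrad"):
ub_algo = (tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P            # one atomic GEMM drains chunks
           if ub.is_atomic_gemm()
           else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P)  # one sub-GEMM per chunk
```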
@@ -801,8 +802,9 @@ class _LayerNormMLP(torch.autograd.Function):
                 gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'),
                 grad=True,
                 gelu_input=fc1_out,
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
-                ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
+                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \
+                    if ctx.ub_overlap_ag else None,
+                ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
             )
         # FC2 WGRAD
@@ -1070,8 +1072,6 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
-            None,
-            None,
         )
@@ -1194,10 +1194,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
         device: Union[torch.device, str] = "cuda",
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
-        ub_split_rs: bool = False,
-        ub_atomic_gemm_rs: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_ag: bool = False,
     ) -> None:
         super().__init__()
@@ -1218,29 +1216,18 @@ class LayerNormMLP(TransformerEngineBaseModule):
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
         self.ub_bulk_wgrad = ub_bulk_wgrad
         self.ub_bulk_dgrad = ub_bulk_dgrad
-        self.ub_split_rs = ub_split_rs
-        self.ub_split_ag = ub_split_ag
-        self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
-        self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+        self.ub_overlap_rs = ub_overlap_rs
+        self.ub_overlap_ag = ub_overlap_ag
         # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
-        self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
-                                 self.activation == 'gelu' and self.ub_split_ag)
+        self.gemm_gelu_fusion = \
+            (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
             self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm())
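The fusion gate also changes meaning: instead of requiring the removed `ub_split_ag` flag, it now requires that the `fc1_fprop` overlap instance not use atomic GEMM, since GEMM-GELU fusion is only supported with the split-pipelined all-gather. A commented restatement of the same logic (names as in the surrounding module):

```python
import os

# Opt-in via env var; force-disabled when fc1 fprop is configured for atomic GEMM.
gemm_gelu_fusion = (
    bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))  # off unless explicitly enabled
    and activation == "gelu"                            # fusion only targets GELU
    and not get_ub("fc1_fprop").is_atomic_gemm()        # split GEMM-AG overlap only
)
```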
-        if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions
-            or ub_bulk_dgrad
-            or ub_split_rs
-            or ub_split_ag
-            or ub_atomic_gemm_rs
-            or ub_atomic_gemm_ag):
+        if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]):
             assert (
                 tex.userbuf_comm_available()
             ), "Userbuffer communication backend not available."
-        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
-            warnings.warn(
-                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
-            )
         if tp_group is None:
             self.tp_size = tp_size
             if tp_size == 1:
@@ -1490,10 +1477,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.primary_weights_in_fp8,
             self.ub_bulk_wgrad,
             self.ub_bulk_dgrad,
-            self.ub_split_rs,
-            self.ub_atomic_gemm_rs,
-            self.ub_split_ag,
-            self.ub_atomic_gemm_ag,
+            self.ub_overlap_rs,
+            self.ub_overlap_ag,
             self.gemm_gelu_fusion,
         )
         out = fwd_fn(*args)
...
@@ -3,7 +3,6 @@
 # See LICENSE for license information.
 """Linear API"""
-import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 import torch
@@ -79,10 +78,8 @@ class _Linear(torch.autograd.Function):
         parallel_mode: Union[str, None],
         is_grad_enabled: bool,
         primary_weights_in_fp8: bool,
-        ub_split_rs: bool,
-        ub_split_ag: bool,
-        ub_atomic_gemm_rs: bool,
-        ub_atomic_gemm_ag: bool,
+        ub_overlap_rs: bool,
+        ub_overlap_ag: bool,
         ub_name: str
     ) -> torch.Tensor:
         # Make sure input dimensions are compatible
@@ -94,14 +91,8 @@ class _Linear(torch.autograd.Function):
         assert_dim_for_fp8_exec(weight)
         update_fp8_weights = is_first_microbatch is None or is_first_microbatch
-        if ub_split_rs or ub_atomic_gemm_rs:
-            tp_world_size = get_distributed_world_size(tp_group)
-            if tp_world_size == 1:
-                ub_split_rs = False
-                ub_atomic_gemm_rs = False
-        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
-            assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
+        tp_world_size = get_distributed_world_size(tp_group)
+        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
         # Cast input to expected dtype
         inputmat = cast_if_needed(inputmat, activation_dtype)
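The single-rank escape hatch collapses into one conditional expression. An equivalent, slightly more explicit spelling (same effect):

```python
# With a tensor-parallel group of one there is nothing to reduce-scatter,
# so the overlap is switched off before any userbuffer is touched.
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
    ub_overlap_rs = False
```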
@@ -180,14 +171,23 @@ class _Linear(torch.autograd.Function):
                 proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                     None, None, None, activation_dtype)
-            if ub_split_rs or ub_atomic_gemm_rs:
+            if ub_overlap_rs:
                 ub_obj_projout = get_ub(ub_name+"_fprop")
                 out = ub_obj_projout.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
                 dim_size[0] = dim_size[0] // tp_world_size
                 dim_size[1] = weight.size(0)
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+                if ub_obj_projout.is_p2p_overlap():
+                    if ub_obj_projout.is_atomic_gemm():
+                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    if ub_obj_projout.is_atomic_gemm():
+                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
+                    else:
+                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
                 if ub_obj_projout.is_fp8_ubuf():
                     proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                     meta_tensor = fp8_meta["scaling_fwd"]
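In `_Linear` both axes of the overlap method are read off the buffer at run time, giving a 2x2 matrix: transport (collective split vs. send/recv ring) by GEMM flavor (split-pipelined vs. atomic). A hedged helper that encodes the same table:

```python
# Hypothetical helper equivalent to the nested branches above; `ub` is the
# communicator returned by get_ub(ub_name + "_fprop").
def pick_rs_algo(ub):
    table = {
        # (is_p2p_overlap, is_atomic_gemm) -> reduce-scatter overlap algorithm
        (True, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P,
        (True, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
        (False, True): tex.UbufOverlapAlgo.ATOMIC_GEMM_RS,
        (False, False): tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS,
    }
    return table[(ub.is_p2p_overlap(), ub.is_atomic_gemm())]
```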
@@ -199,8 +199,6 @@ class _Linear(torch.autograd.Function):
                     dim_size[1] = weight.size(0)
                     out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
-            ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
-            ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
             _ = fp8_gemm(
                 weight_fp8._data,
                 fp8_meta["scaling_fwd"].scale_inv,
@@ -216,9 +214,9 @@ class _Linear(torch.autograd.Function):
                 use_bias=use_bias,
                 use_split_accumulator=_2X_ACC_FPROP,
                 out=out,
-                ub_algo=ub_algo,
-                ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None,
-                extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None,
+                ub_algo=ub_algo if ub_overlap_rs else None,
+                ub=ub_obj_projout if ub_overlap_rs else None,
+                extra_output_tensor=rs_out if ub_overlap_rs else None,
                 out_index=proj_out_index,
                 fp8_meta_tensor = meta_tensor,
                 D_dtype = proj_out_tetype,
@@ -238,13 +236,17 @@ class _Linear(torch.autograd.Function):
                 fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                     torch.max(-amin, amax).float()
-            if ub_split_rs:
-                ub_obj_projout = get_ub("proj_fprop")
+            if ub_overlap_rs:
+                ub_obj_projout = get_ub(ub_name+"_fprop")
                 out = ub_obj_projout.get_ubuf_output(1)
                 dim_size = list(inputmat_total.size())
-                dim_size[0] = dim_size[0] // tp_world_size
+                dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
                 dim_size[1] = weight.size(0)
                 rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
+                if ub_obj_projout.is_p2p_overlap():
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
+                else:
+                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
             else:
                 dim_size = list(inputmat_total.size())
                 dim_size[1] = weight.size(0)
@@ -258,9 +260,9 @@ class _Linear(torch.autograd.Function):
                 bias=bias,
                 use_bias=use_bias,
                 out=out,
-                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
-                ub=ub_obj_projout if ub_split_rs else None,
-                extra_output_tensor=rs_out if ub_split_rs else None,
+                ub_algo=ub_algo if ub_overlap_rs else None,
+                ub=ub_obj_projout if ub_overlap_rs else None,
+                extra_output_tensor=rs_out if ub_overlap_rs else None,
             )
         if is_grad_enabled:
@@ -307,14 +309,13 @@ class _Linear(torch.autograd.Function):
             ctx.inp_shape = inp.shape
             ctx.parallel_mode = parallel_mode
             ctx.tp_group = tp_group
-            ctx.ub_split_ag = ub_split_ag
-            ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
+            ctx.ub_overlap_ag = ub_overlap_ag
             ctx.ub_name = ub_name
             ctx.tp_size = tp_size
             ctx.requires_dgrad = inp.requires_grad
         # Row Parallel Linear
-        if ub_split_rs or ub_atomic_gemm_rs:
+        if ub_overlap_rs:
             out = rs_out
         elif parallel_mode == "row" and sequence_parallel:
             out, _ = reduce_scatter_along_first_dim(out, tp_group)
@@ -350,16 +351,16 @@ class _Linear(torch.autograd.Function):
                 weight_t_fp8 = weight.transpose(
                     update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
                 )
-        if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
-            tp_world_size = get_distributed_world_size(ctx.tp_group)
-            if tp_world_size == 1:
-                ctx.ub_split_ag = False
-                ctx.ub_atomic_gemm_ag = False
-        if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
+        tp_world_size = get_distributed_world_size(ctx.tp_group)
+        ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
+        if ctx.ub_overlap_ag:
             dim_size = list(grad_output.size())
             dim_size[0] = dim_size[0] * tp_world_size
             ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad")
+            if ctx.ub_obj_gradout.is_atomic_gemm():
+                ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
+            else:
+                ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
         (
             grad_output,
             grad_output_c,
@@ -397,8 +398,6 @@ class _Linear(torch.autograd.Function):
                 ctx.fp8_meta["recipe"], fprop_tensor=False
             )
-            ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
-            ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
         if ctx.requires_dgrad:
             if ctx.fp8:
                 dgrad, _ = fp8_gemm(
@@ -413,8 +412,8 @@ class _Linear(torch.autograd.Function):
                     ctx.activation_dtype,
                     get_workspace(),
                     use_split_accumulator=_2X_ACC_DGRAD,
-                    ub_algo=ub_algo,
-                    ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None,
+                    ub_algo=ub_algo if ctx.ub_overlap_ag else None,
+                    ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                 )
             else:
                 dgrad, _, _ = gemm(
@@ -424,8 +423,9 @@ class _Linear(torch.autograd.Function):
                     get_workspace(),
                     layout="NN",
                     grad=True,
-                    ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
-                    ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
+                    ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \
+                        if ctx.ub_overlap_ag else None,
+                    ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                 )
         # Overlap dgrad-RS/AR with wgrad
@@ -442,7 +442,7 @@ class _Linear(torch.autograd.Function):
         if ctx.fp8:
             # WGRAD
             if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
-                if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
+                if ctx.ub_overlap_ag:
                     grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                 if inputmat_t_total is None:
                     inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
@@ -542,8 +542,6 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
-            None,
-            None,
         )
@@ -629,10 +627,8 @@ class Linear(TransformerEngineBaseModule):
         parallel_mode: Optional[str] = None,
         parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
         device: Union[torch.device, str] = "cuda",
-        ub_split_rs: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_rs: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
     ) -> None:
         super().__init__()
@@ -645,28 +641,18 @@ class Linear(TransformerEngineBaseModule):
         self.return_bias = return_bias
         self.apply_bias = bias and not return_bias
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
-        self.ub_split_rs = ub_split_rs
-        self.ub_split_ag = ub_split_ag
-        self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
-        self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
-        if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
+        self.ub_overlap_rs = ub_overlap_rs
+        self.ub_overlap_ag = ub_overlap_ag
+        if ub_overlap_rs or ub_overlap_ag:
             assert ub_name is not None, "Userbuffer name [string] is not set."
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."
         self.ub_name = ub_name
         self.get_rng_state_tracker = get_rng_state_tracker
         if device == 'meta':
             assert parameters_split is None, ("Cannot split module parameters "
                                               "on 'meta' device.")
-        if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
-            assert (
-                tex.userbuf_comm_available()
-            ), "Userbuffer communication backend not available."
-        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
-            warnings.warn(
-                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
-            )
         if tp_group is None:
             self.tp_size = tp_size
             if tp_size == 1:
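With the atomic-GEMM flags folded away, enabling overlap on a `Linear` takes only the two unified flags plus a userbuffer name. A hedged usage sketch (sizes and the `tp_group` handle are placeholders; the matching `proj_fprop`/`proj_dgrad` buffers must already be registered via userbuffer initialization):

```python
import transformer_engine.pytorch as te

proj = te.Linear(
    4096, 4096,                 # placeholder in/out features
    tp_group=tp_group,          # an initialized tensor-parallel process group
    parallel_mode="row",
    sequence_parallel=True,
    ub_overlap_rs=True,         # overlap fprop GEMM with reduce-scatter
    ub_overlap_ag=True,         # overlap dgrad GEMM with all-gather
    ub_name="proj",             # selects the "proj_*" userbuffer instances
)
```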
@@ -930,10 +916,8 @@ class Linear(TransformerEngineBaseModule):
             self.parallel_mode,
             torch.is_grad_enabled(),
             self.primary_weights_in_fp8,
-            self.ub_split_rs,
-            self.ub_split_ag,
-            self.ub_atomic_gemm_rs,
-            self.ub_atomic_gemm_ag,
+            self.ub_overlap_rs,
+            self.ub_overlap_ag,
             self.ub_name,
         )
         out = linear_fn(*args)
...
@@ -259,10 +259,8 @@ class TransformerLayer(torch.nn.Module):
         ub_tp_comm_overlap: bool = False,
         ub_bulk_wgrad: bool = True,
         ub_bulk_dgrad: bool = True,
-        ub_split_ag: bool = True,
-        ub_split_rs: bool = True,
-        ub_atomic_gemm_ag: bool = False,
-        ub_atomic_gemm_rs: bool = False,
+        ub_overlap_ag: bool = True,
+        ub_overlap_rs: bool = True,
         bias: bool = True,
         activation: str = 'gelu',
         normalization: str = "LayerNorm",
@@ -282,21 +280,8 @@ class TransformerLayer(torch.nn.Module):
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
         ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
-        ub_split_ag = ub_tp_comm_overlap and ub_split_ag
-        ub_split_rs = ub_tp_comm_overlap and ub_split_rs
-        ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs
-        assert (
-            not (ub_split_rs and ub_atomic_gemm_rs)
-        ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled."
-        ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag
-        assert (
-            not (ub_split_ag and ub_atomic_gemm_ag)
-        ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled."
-        if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
-            warnings.warn(
-                "Atomic gemm uses a beta API from cublas and is not tested for all use cases."
-            )
+        ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
+        ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
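At the `TransformerLayer` level the per-method asserts and the atomic-GEMM warning are gone; every overlap flag is simply masked by the `ub_tp_comm_overlap` master switch. A hedged sketch of the resulting surface:

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=4096,            # placeholder model sizes
    ffn_hidden_size=16384,
    num_attention_heads=32,
    ub_tp_comm_overlap=True,     # master switch; without it the flags below are no-ops
    ub_overlap_ag=True,          # AG overlap (defaults to True)
    ub_overlap_rs=True,          # RS overlap (defaults to True)
)
```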
@@ -370,10 +355,8 @@ class TransformerLayer(torch.nn.Module):
             "qkv_weight_interleaved" : qkv_weight_interleaved,
             "ub_bulk_wgrad" : ub_bulk_wgrad,
             "ub_bulk_dgrad" : ub_bulk_dgrad,
-            "ub_split_ag" : ub_split_ag,
-            "ub_split_rs" : ub_split_rs,
-            "ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
-            "ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
+            "ub_overlap_ag" : ub_overlap_ag,
+            "ub_overlap_rs" : ub_overlap_rs,
             "qkv_format" : self.attn_input_format,
         }
@@ -427,10 +410,8 @@ class TransformerLayer(torch.nn.Module):
             zero_centered_gamma=zero_centered_gamma,
             ub_bulk_wgrad=ub_bulk_wgrad,
             ub_bulk_dgrad=ub_bulk_dgrad,
-            ub_split_rs=ub_split_rs,
-            ub_split_ag=ub_split_ag,
-            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
+            ub_overlap_rs=ub_overlap_rs,
+            ub_overlap_ag=ub_overlap_ag,
             activation=activation,
             normalization=normalization,
             device=device,
...
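Which concrete algorithm each instance ends up with (`is_p2p_overlap()`, `is_atomic_gemm()`) is decided when the userbuffers are initialized, not by the module flags. A heavily hedged sketch of such a configuration; the helper location and config keys reflect this change but may differ across versions:

```python
from transformer_engine.pytorch.module.base import initialize_ub

# Placeholder dimensions; assumes torch.distributed is already initialized.
seq_len, batch_size, hidden_size, tp_size = 2048, 4, 4096, 8

initialize_ub(
    shape=[seq_len * batch_size, hidden_size],
    tp_size=tp_size,
    use_fp8=True,
    ub_cfgs={
        "proj_fprop": {"method": "ring_exchange"},  # RS via send/recv ring
        "fc2_fprop": {"method": "pipeline"},        # RS via collective split
        "qkv_dgrad": {"method": "ring_exchange"},   # AG via send/recv ring
    },
)
```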