Unverified commit b855656b authored by Sangkug Lym, committed by GitHub

TP-RS overlap with send/recv ring-exchange (#724)



* TP-RS overlap with send/recv

Atomic GEMM based TP-RS overlap with send/recv
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Specify userbuffer overlap method of each overlap instance
Signed-off-by: Sangkug Lym <slym@nvidia.com>

P2P TP-RS overlap with fp8 GEMM outputs
Signed-off-by: Sangkug Lym <slym@nvidia.com>

Fix TP-RS overlap with send/recv
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* linting
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix typo
Signed-off-by: Sangkug Lym <slym@nvidia.com>

---------
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 59bfc17b
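The user-facing effect of this change: the four per-algorithm flags (`ub_split_rs`, `ub_split_ag`, `ub_atomic_gemm_rs`, `ub_atomic_gemm_ag`) collapse into two per-direction flags (`ub_overlap_rs`, `ub_overlap_ag`). Whether a given overlap instance uses a pipelined split GEMM or an atomic GEMM, over collective or send/recv ring-exchange transport, becomes a property of the userbuffer configured in `initialize_ub` and is queried at runtime via `is_atomic_gemm()`/`is_p2p_overlap()`. A minimal sketch of the new surface (the surrounding model setup is hypothetical):

```python
import transformer_engine.pytorch as te

# One overlap flag per communication direction; the overlap algorithm
# itself is configured per userbuffer in initialize_ub().
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    ub_tp_comm_overlap=True,
    ub_overlap_ag=True,  # replaces ub_split_ag / ub_atomic_gemm_ag
    ub_overlap_rs=True,  # replaces ub_split_rs / ub_atomic_gemm_rs
)
```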
......@@ -3175,10 +3175,8 @@ class MultiheadAttention(torch.nn.Module):
qkv_weight_interleaved: bool = True,
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
bias: bool = True,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
......@@ -3265,9 +3263,8 @@ class MultiheadAttention(torch.nn.Module):
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs,
)
......@@ -3297,9 +3294,8 @@ class MultiheadAttention(torch.nn.Module):
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_ag=ub_split_ag,
ub_overlap_ag=ub_overlap_ag,
normalization=normalization,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_name="qkv",
**common_gemm_kwargs,
)
......@@ -3347,10 +3343,8 @@ class MultiheadAttention(torch.nn.Module):
bias=bias,
return_bias=return_bias,
parallel_mode="row" if set_parallel_mode else None,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
ub_name="proj",
**common_gemm_kwargs,
)
......
......@@ -101,14 +101,14 @@ def fp8_gemm(
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (0, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
fn = ub.atomic_gemm_overlap_ag
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
fn = ub.atomic_gemm_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
......@@ -119,12 +119,24 @@ def fp8_gemm(
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
fn = ub.atomic_gemm_overlap_rs
assert (
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
fn = ub.atomic_gemm_overlap_rs_p2p
assert (
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,))
_ = fn(*args)
return out, gelu_input
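For reference, the dispatch that `fp8_gemm` performs after these renames, summarized as a table (illustrative, not library code); the boolean records whether the algorithm asserts a non-`None` `extra_output_tensor`. Bulk overlaps, untouched by this diff, still dispatch to `bulk_overlap`.

```python
# Overlap algorithm -> (communicator method, extra_output_tensor required).
UB_ALGO_DISPATCH = {
    "SPLIT_PIPELINED_AG_P2P": ("split_overlap_ag_p2p",       False),
    "ATOMIC_GEMM_AG_P2P":     ("atomic_gemm_overlap_ag_p2p", False),
    "SPLIT_PIPELINED_RS":     ("split_overlap_rs",           True),
    "SPLIT_PIPELINED_RS_P2P": ("split_overlap_rs_p2p",       True),
    "ATOMIC_GEMM_RS":         ("atomic_gemm_overlap_rs",     True),
    "ATOMIC_GEMM_RS_P2P":     ("atomic_gemm_overlap_rs_p2p", True),
}
```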
......@@ -217,8 +229,8 @@ def gemm(
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0, empty_tensor))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
fn = ub.split_overlap_ag_p2p
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
......@@ -229,6 +241,12 @@ def gemm(
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (False, extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
fn = ub.split_overlap_rs_p2p
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,))
_ = fn(*args)
return out, grad_bias, gelu_input
......@@ -41,10 +41,12 @@ enum class COMM_TYPE { RS = 0, AG = 1 };
enum class UBOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG = 2,
SPLIT_PIPELINED_AG_P2P = 2,
SPLIT_PIPELINED_RS = 3,
ATOMIC_GEMM_RS = 4,
ATOMIC_GEMM_AG = 5
SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7
};
struct UbufBase {
......@@ -70,9 +72,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int comm_sms;
int cga_size;
int use_ce;
bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size,
int num_splits, bool set_sm_margin, int num_max_streams,
int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm,
torch::Tensor empty_tensor) {
// Initialize userbuf communicator
if (!comm_created) {
......@@ -116,9 +119,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
output_tensor = torch::Tensor();
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({num_splits * 2}, counter_options);
counter.index_put_({Slice(None, num_splits)}, 1);
_atomic_gemm = atomic_gemm;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({num_splits * 2}, counter_options);
counter.index_put_({Slice(None, num_splits)}, 1);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
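The counter allocated above implements a producer/consumer handshake between the atomic GEMM and the communication kernels: the first `num_splits` entries start at 1 ("chunk not yet produced"), the atomic GEMM zeroes an entry when the corresponding output chunk is complete, and the communication side spins on it before touching the chunk. A hedged host-side sketch of the protocol (the real spin-wait runs in the `consumer` device kernel):

```python
import torch

num_splits = 4
counter = torch.zeros(num_splits * 2, dtype=torch.int32)
counter[:num_splits] = 1  # mark every chunk as not-yet-produced

def on_chunk_complete(chunk_i: int) -> None:
    # Role played by the atomic GEMM: publish chunk `chunk_i`.
    counter[chunk_i] = 0

def wait_for_chunk(chunk_i: int) -> None:
    # Role played by consumer() on the comm stream: block until published.
    while int(counter[chunk_i]) != 0:
        pass
```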
......@@ -519,12 +525,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
return output_tensor;
}
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return false; }
}; // UbufCommOverlap
struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int _tp_id;
int _tp_size;
int _ub_reg;
int _ub_reg, _ub_reg2;
int _next_rank, _prev_rank, _rank, _rank_round_tp;
int _aggregate2;
int _math_sms;
......@@ -533,18 +542,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
torch::Tensor _ubuf;
torch::Tensor counter;
torch::Tensor _empty_tensor;
torch::Tensor _ubuf_scale_inv;
bool _ubuf_scale_inv_initialized;
std::vector<torch::Tensor> _ubufs;
at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true);
at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv;
int use_ce;
int sms;
int cga_size;
bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm,
int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams,
torch::Tensor empty_tensor) {
bool is_reduce_scatter, bool atomic_gemm, torch::Tensor empty_tensor) {
// Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
......@@ -561,16 +573,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size();
int ubuf_chunk_bytes = ubuf_bytes / tp_size;
int num_ubuf_chunks = tp_size;
if (is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipeline.
ubuf_bytes = static_cast<int>(ubuf_bytes / tp_size * (tp_size * 2 - 1));
num_ubuf_chunks = static_cast<int>(tp_size * 2 - 1);
}
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
if (rank == 0) {
printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
}
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
_ubuf = torch::from_blob(
_ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options());
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < tp_size; i++) {
for (int i = 0; i < num_ubuf_chunks; i++) {
torch::Tensor ubuf_chunk = torch::from_blob(
ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options());
_ubufs.push_back(ubuf_chunk);
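A quick check of the sizing above: for GEMM+RS overlap each rank holds its `tp_size` locally produced GEMM chunks plus the `tp_size - 1` chunks received from peers, hence `2 * tp_size - 1` chunks in total (illustrative arithmetic):

```python
def num_ubuf_chunks(tp_size: int, is_reduce_scatter: bool) -> int:
    # Mirrors the constructor logic above.
    return 2 * tp_size - 1 if is_reduce_scatter else tp_size

assert num_ubuf_chunks(8, is_reduce_scatter=False) == 8
assert num_ubuf_chunks(8, is_reduce_scatter=True) == 15  # 8 local + 7 received
```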
......@@ -599,30 +620,37 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_rank_round_tp = (rank / tp_size) * tp_size;
_next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp;
_prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp;
_ubuf_scale_inv_initialized = false;
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
_atomic_gemm = atomic_gemm;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;
if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
}
}
counter.index_put_({_self_chunk_id}, 0);
}
}
counter.index_put_({_self_chunk_id}, 0);
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_send, 0);
cudaEventCreateWithFlags(&_stop_recv, 0);
}
......@@ -758,7 +786,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
return D;
} // split_overlap_ag
} // atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is
......@@ -948,6 +977,174 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
return D;
} // split_overlap_ag
/*
** Atomic GEMM + ReduceScatter overlap using P2P communication
*/
void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
int k = A.size(1);
int n = B.size(0);
// Get communication and GEMM input chunk sizes
int n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
// Atomic GEMM
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
_ubuf, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, counter);
// P2P communication chunk
for (int i = 1; i < _tp_size; i++) {
int send_chunk_id = i - 1;
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, send_rank, (cudaStream_t) _stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
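The loop above walks a send/recv ring: at step `i` the rank forwards chunk `i - 1` once the atomic GEMM has published it, and receives a peer's chunk into slot `i - 1 + tp_size`. A pure index sketch of that schedule (hedged; the `_rank_round_tp` offset into global ranks is omitted):

```python
def ring_schedule(tp_id: int, tp_size: int):
    """Yield (step, send_chunk, recv_chunk, send_rank, recv_rank)."""
    for i in range(1, tp_size):
        send_chunk = i - 1
        recv_chunk = send_chunk + tp_size  # received chunks land after the
                                           # tp_size locally produced chunks
        send_rank = (tp_id + i) % tp_size
        recv_rank = (tp_size + tp_id - i) % tp_size
        yield i, send_chunk, recv_chunk, send_rank, recv_rank

# For tp_id=0, tp_size=4: (1, 0, 4, 1, 3), (2, 1, 5, 2, 2), (3, 2, 6, 3, 1)
for step in ring_schedule(0, 4):
    print(step)
```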
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
at::Tensor rs_output) {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
int k = A.size(1);
int n = B.size(0);
// Get communication and GEMM input chunk sizes
int n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];
// Catch up the main stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
char* input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes);
torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options());
// Store the last GEMM chunk output to the receive buffer.
torch::Tensor workspace_chunk = torch::from_blob(
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
if (i == _tp_size - 1) {
at::cuda::setCurrentCUDAStream(stream_main);
} else {
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
}
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
_ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
if (i > 0) {
// P2P communication chunk
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t) _stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_send, _start_comm, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, send_rank, (cudaStream_t) _stream_send);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
}
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
_tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
}
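Note the input-chunk order in the loop above: each rank computes its peers' output chunks first and its own chunk last, and the final GEMM runs on the main stream because its output stays local rather than entering the ring. A hedged sketch of just that ordering:

```python
def gemm_input_chunk_order(tp_id: int, tp_size: int) -> list:
    # Mirrors input_b_chunk_id = (_tp_id + i + 1) % _tp_size above.
    return [(tp_id + i + 1) % tp_size for i in range(tp_size)]

assert gemm_input_chunk_order(tp_id=2, tp_size=4) == [3, 0, 1, 2]  # own chunk last
```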
/*
** Copy input to _ubufs[0]
*/
......@@ -970,6 +1167,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
(cudaStream_t)stream_main));
}
}
torch::Tensor get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
......@@ -981,6 +1179,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
}
void set_ubuf_scale_inv(const torch::Tensor &scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return true; }
}; // UbufP2PCommOverlap
} // namespace ubuf
......@@ -109,26 +109,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG)
.value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P)
.value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P)
.value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS)
.value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG);
.value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P)
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, torch::Tensor>())
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, bool, torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
.def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs)
.def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output)
.def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm)
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, torch::Tensor>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, bool, bool, torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
.def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output)
.def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf)
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm)
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap)
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv);
#else // NVTE_WITH_USERBUFFERS
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
......
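With these bindings, the Python side can introspect any communicator created by `initialize_ub` rather than tracking the algorithm in module flags. A small hedged helper showing the new query surface (illustrative, not TE API):

```python
def describe_ub(ub) -> str:
    """Summarize a userbuffer communicator's overlap configuration."""
    transport = "P2P ring-exchange" if ub.is_p2p_overlap() else "collective"
    gemm = "atomic" if ub.is_atomic_gemm() else "pipelined split"
    buf = "FP8" if ub.is_fp8_ubuf() else "non-FP8"
    return f"{transport} overlap, {gemm} GEMM, {buf} userbuffer"
```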
......@@ -3666,3 +3666,34 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 grid(1);
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
}
template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
const int num_inputs, const int input_size) {
const size_t tid = threadIdx.x + blockDim.x * blockIdx.x;
fp8type *inputs_fp8 = reinterpret_cast<fp8type *>(inputs);
float accum_buf = static_cast<float>(inputs_fp8[tid]) * (*scale);
#pragma unroll
for (int i = 1; i < num_inputs; i++) {
accum_buf += static_cast<float>(inputs_fp8[tid + input_size * i]) * (*scale);
}
// Output buffer is bf16, matching the function name; write __nv_bfloat16, not half.
__nv_bfloat16 *output_bf16 = reinterpret_cast<__nv_bfloat16 *>(output);
output_bf16[tid] = (__nv_bfloat16) accum_buf;
}
template <typename fp8type>
void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs,
int input_size, cudaStream_t stream) {
size_t num_threads = MAX_THREADS / 4;
size_t num_blocks = (input_size + num_threads - 1) / num_threads;
dim3 block(num_threads);
dim3 grid(num_blocks);
reduce_fp8_in_bf16_out_cuda<fp8type><<<grid, block, 0, stream>>>(
inputs, output, scale, num_inputs, input_size);
}
template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(
void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream);
template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(
void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream);
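The kernel dequantizes `num_inputs` stacked FP8 blocks with a single inverse scale and accumulates in FP32 before narrowing to 16 bits. A reference implementation of the same math (hedged; assumes a recent PyTorch with FP8 dtypes):

```python
import torch

def reduce_fp8_in_bf16_out_ref(inputs_fp8: torch.Tensor, scale_inv: float,
                               num_inputs: int, input_size: int) -> torch.Tensor:
    # inputs_fp8: flat tensor of num_inputs contiguous blocks of input_size
    # elements, dtype torch.float8_e4m3fn or torch.float8_e5m2.
    blocks = inputs_fp8.view(num_inputs, input_size).to(torch.float32) * scale_inv
    return blocks.sum(dim=0).to(torch.bfloat16)
```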
......@@ -305,4 +305,8 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0);
void destroy_communicator(communicator *comm);
template <typename fp8type>
void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs,
int input_size, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
......@@ -129,13 +129,14 @@ def initialize_ub(
"qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
]
if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))):
fp8_buf.append ("proj_fprop")
fp8_buf += ["proj_fprop", "fc2_fprop"]
# Default overlap methods for layers
methods = {
"ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
"pipeline":["proj_fprop", "fc2_fprop"],
"bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
}
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
def get_method(name):
for method, names in methods.items():
......@@ -151,7 +152,28 @@ def initialize_ub(
set_sm_margin: int = 0,
num_splits: int = 4,
aggregate: int = 0,
atomic_gemm: int = 0,
is_reduce_scatter: int = 0,
) -> None:
if atomic_gemm:
warnings.warn(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if is_reduce_scatter and method == "ring_exchange":
raise ValueError(
"Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method."
)
if method == 'bulk':
warnings.warn(
"Atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`."
)
atomic_gemm = 0
if not is_reduce_scatter and method == 'pipeline':
raise ValueError(
"`pipeline` overlap method is not supported for AllGather."
)
sample_buffer = torch.empty(
shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
......@@ -166,6 +188,8 @@ def initialize_ub(
set_sm_margin, # Set SM margin
aggregate, # Aggregate 2X GEMM chunks
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
is_reduce_scatter, # overlap with reduce scatter
atomic_gemm, # use a single GEMM with atomic-counters
torch.Tensor(), # empty tensor to pass to counters
)
else:
......@@ -178,6 +202,7 @@ def initialize_ub(
num_splits, # Number of communication splits
set_sm_margin, # Set SM margin
_NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams
atomic_gemm, # use a single GEMM with atomic-counters
torch.Tensor(), # empty tensor to pass to counters
)
_ub_communicators[name] = ub_obj
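Putting the new constructor arguments together, a hypothetical `initialize_ub` call that opts the reduce-scatter buffers into atomic GEMM; only `ub_cfgs` keys read in this diff are used, and the exact call-site shape and import path are assumptions:

```python
import torch
import transformer_engine.pytorch as te

seq_len, batch, hidden, tp = 2048, 2, 4096, 8
te.initialize_ub(
    shape=[seq_len * batch, hidden],
    tp_size=tp,
    use_fp8=True,  # required whenever atomic_gemm is requested
    dtype=torch.bfloat16,
    ub_cfgs={
        # Both buffers default to the RS-overlapping `pipeline` method.
        "proj_fprop": {"atomic_gemm": 1, "num_splits": 4},
        "fc2_fprop": {"atomic_gemm": 1, "num_splits": 4},
    },
)
```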
......@@ -191,6 +216,8 @@ def initialize_ub(
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
add_ub(
name,
method,
......@@ -198,7 +225,9 @@ def initialize_ub(
cga_size,
set_sm_margin,
num_splits,
aggregate
aggregate,
atomic_gemm,
is_reduce_scatter,
)
else:
method = get_method(name)
......@@ -632,12 +661,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
gather_grad_output = row_parallel_mode and ctx.sequence_parallel
if gather_grad_output:
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
# No-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8:
if gather_grad_output:
if not ub_overlap_ag:
if not ctx.ub_overlap_ag:
grad_output_mat, _ = gather_along_first_dim(
grad_output_mat, ctx.tp_group
)
......@@ -656,7 +683,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
):
assert (
not ub_overlap_ag
not ctx.ub_overlap_ag
), "override_linear_precision.wgrad not supported with UB AG overlap"
grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
# FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
......@@ -665,7 +692,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
grad_bias = grad_output_mat.sum(dim=0)
else:
grad_bias = None
if ub_overlap_ag:
if ctx.ub_overlap_ag:
grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
else:
grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
......@@ -676,7 +703,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
fp8_dtype_backward,
out=grad_output_c,
)
if not ub_overlap_ag:
if not ctx.ub_overlap_ag:
grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
else:
......
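The branches above reduce to a simple rule: when AG overlap is active, the all-gather of the incoming gradient is deferred to (and overlapped with) the dgrad GEMM; otherwise it happens eagerly. A hedged mirror of the non-FP8 path (`gather_along_first_dim` is the TE helper already used above; the import path is an assumption):

```python
from transformer_engine.pytorch.distributed import gather_along_first_dim

def preprocess_grad_output(ctx, grad_output_mat):
    # Non-FP8 path: gather eagerly only when AG overlap is off.
    if ctx.ub_overlap_ag:
        return grad_output_mat  # gathered inside the overlapped dgrad GEMM
    gathered, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
    return gathered
```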
......@@ -86,8 +86,7 @@ class _LayerNormLinear(torch.autograd.Function):
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
ub_atomic_gemm_ag: bool,
ub_overlap_ag: bool,
ub_name: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
......@@ -106,12 +105,11 @@ class _LayerNormLinear(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag or ub_atomic_gemm_ag:
if ub_overlap_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
ub_atomic_gemm_ag = False
if ub_split_ag or ub_atomic_gemm_ag:
ub_overlap_ag = False
if ub_overlap_ag:
dim_size = list(inputmat.size())
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub(ub_name+"_fprop")
......@@ -119,8 +117,6 @@ class _LayerNormLinear(torch.autograd.Function):
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -138,9 +134,13 @@ class _LayerNormLinear(torch.autograd.Function):
# Column Parallel Linear
ln_out_gathered = False
if ub_split_ag or ub_atomic_gemm_ag:
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
if ub_obj_lnout.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif parallel_mode == "column" and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
......@@ -201,8 +201,6 @@ class _LayerNormLinear(torch.autograd.Function):
)
weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
out, _ = tex.fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -217,9 +215,9 @@ class _LayerNormLinear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=ub_algo,
ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None,
extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None,
ub_algo=ub_algo if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
else:
# Cast for native AMP
......@@ -243,9 +241,9 @@ class _LayerNormLinear(torch.autograd.Function):
get_workspace(),
bias=bias,
use_bias=use_bias,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
if is_grad_enabled:
......@@ -624,7 +622,6 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -737,8 +734,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_ag: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -758,23 +754,16 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]):
self.ub_overlap_ag = ub_overlap_ag
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag:
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -1098,8 +1087,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
self.ub_atomic_gemm_ag,
self.ub_overlap_ag,
self.ub_name,
)
out = fwd_fn(*args)
......
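On the all-gather side the choice is one-dimensional, since AG overlap always rides the P2P ring-exchange transport; the selection in the forward above reduces to this hedged helper (the extension import path is an assumption):

```python
import transformer_engine.pytorch.cpp_extensions as tex

def pick_ag_algo(ub_obj):
    return (tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
            if ub_obj.is_atomic_gemm()
            else tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P)
```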
......@@ -117,10 +117,8 @@ class _LayerNormMLP(torch.autograd.Function):
primary_weights_in_fp8: bool,
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_rs: bool,
ub_atomic_gemm_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_ag: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
gemm_gelu_fusion: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
......@@ -142,25 +140,17 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag or ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(tp_group)
tp_world_size = get_distributed_world_size(tp_group)
if ub_overlap_ag:
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
ub_split_ag = False
ub_atomic_gemm_ag = False
ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag
ub_overlap_ag = False
if ub_overlap_ag:
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_split_rs or ub_atomic_gemm_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
ub_atomic_gemm_rs = False
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
......@@ -181,6 +171,10 @@ class _LayerNormMLP(torch.autograd.Function):
if ub_overlap_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
ln_out = torch.empty_like(ln_out)
if ub_obj_lnout.is_atomic_gemm():
ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif set_parallel_mode and sequence_parallel:
ln_out_gathered = True
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
......@@ -267,9 +261,6 @@ class _LayerNormMLP(torch.autograd.Function):
)
fc2_weight_t_fp8 = None
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo
# Perform FP8 GEMM
fp8_gemm_args = [
fc1_weight_fp8._data,
......@@ -287,7 +278,7 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=use_fc1_bias,
use_split_accumulator=_2X_ACC_FPROP,
ub_algo=ub_algo,
ub_algo=ub_algo_ag if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
......@@ -321,13 +312,23 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = (
None, None, None, activation_dtype)
if ub_split_rs or ub_atomic_gemm_rs:
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
if ub_obj_fc2out.is_p2p_overlap():
if ub_obj_fc2out.is_atomic_gemm():
ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ub_obj_fc2out.is_atomic_gemm():
ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
if ub_obj_fc2out.is_fp8_ubuf():
fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT
......@@ -340,8 +341,6 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = tex.fp8_gemm(
fc2_weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -357,9 +356,9 @@ class _LayerNormMLP(torch.autograd.Function):
use_bias=use_fc2_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=fc2_out,
ub_algo=ub_algo,
ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None,
extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None,
ub_algo=ub_algo_rs if ub_overlap_rs else None,
ub=ub_obj_fc2out if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
out_index=fc2_out_index,
fp8_meta_tensor = fc2_meta_tensor,
D_dtype = fc2_te_type,
......@@ -395,9 +394,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc1_bias,
use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
gelu=not bias_gelu_nvfusion and (activation == 'gelu'),
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
ub=ub_obj_lnout if ub_split_ag else None,
extra_output_tensor=ln_out if ub_split_ag else None,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None,
ub=ub_obj_lnout if ub_overlap_ag else None,
extra_output_tensor=ln_out if ub_overlap_ag else None,
)
if not is_grad_enabled:
clear_tensor_data(ln_out_total)
......@@ -427,13 +426,17 @@ class _LayerNormMLP(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
torch.max(-amin, amax).float()
if ub_split_rs:
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
fc2_out = ub_obj_fc2out.get_ubuf_output(1)
dim_size = list(gelu_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
if ub_obj_fc2out.is_p2p_overlap():
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else:
dim_size = list(gelu_out.size())
dim_size[1] = fc2_weight.size(0)
......@@ -446,9 +449,9 @@ class _LayerNormMLP(torch.autograd.Function):
bias=fc2_bias,
use_bias=use_fc2_bias,
out=fc2_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_fc2out if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
ub_algo=ub_algo_rs if ub_overlap_rs else None,
ub=ub_obj_fc2out if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
)
if not is_grad_enabled:
clear_tensor_data(gelu_out)
......@@ -515,13 +518,12 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.zero_centered_gamma = zero_centered_gamma
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.ub_overlap_ag = ub_overlap_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
if ub_split_rs or ub_atomic_gemm_rs:
if ub_overlap_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
......@@ -590,18 +592,19 @@ class _LayerNormMLP(torch.autograd.Function):
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("fc1_dgrad")
ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
if ub_overlap_ag:
if ctx.ub_overlap_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
ctx.ub_overlap_ag = False
ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag
if ub_overlap_ag:
if ctx.ub_overlap_ag:
dim_size = list(grad_outputs[0].size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub("fc2_dgrad")
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess
(
......@@ -645,8 +648,6 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
# FC2 DGRAD; Unconditional
fc2_dgrad, _ = tex.fp8_gemm(
fc2_weight_t_fp8._data,
......@@ -660,10 +661,10 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=ub_algo,
ub=ctx.ub_obj_gradout if ub_overlap_ag else None,
ub_algo=ub_algo if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
)
if ub_overlap_ag:
if ctx.ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
clear_tensor_data(grad_output_c)
......@@ -801,8 +802,9 @@ class _LayerNormMLP(torch.autograd.Function):
gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'),
grad=True,
gelu_input=fc1_out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \
if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
)
# FC2 WGRAD
......@@ -1070,8 +1072,6 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -1194,10 +1194,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
device: Union[torch.device, str] = "cuda",
ub_bulk_wgrad: bool = False,
ub_bulk_dgrad: bool = False,
ub_split_rs: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_ag: bool = False,
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
) -> None:
super().__init__()
......@@ -1218,29 +1216,18 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_bulk_wgrad = ub_bulk_wgrad
self.ub_bulk_dgrad = ub_bulk_dgrad
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
# GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
self.activation == 'gelu' and self.ub_split_ag)
if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions
or ub_bulk_dgrad
or ub_split_rs
or ub_split_ag
or ub_atomic_gemm_rs
or ub_atomic_gemm_ag):
self.gemm_gelu_fusion = \
(bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm())
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -1490,10 +1477,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.primary_weights_in_fp8,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_rs,
self.ub_atomic_gemm_rs,
self.ub_split_ag,
self.ub_atomic_gemm_ag,
self.ub_overlap_rs,
self.ub_overlap_ag,
self.gemm_gelu_fusion,
)
out = fwd_fn(*args)
......
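The reduce-scatter side is a 2x2 choice: collective vs. P2P transport (from `is_p2p_overlap()`) crossed with pipelined vs. atomic GEMM (from `is_atomic_gemm()`). The same selection reappears in `_Linear` below; a hedged helper equivalent to those branches (extension import path assumed):

```python
import transformer_engine.pytorch.cpp_extensions as tex

def pick_rs_algo(ub_obj):
    if ub_obj.is_p2p_overlap():
        return (tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
                if ub_obj.is_atomic_gemm()
                else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P)
    return (tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
            if ub_obj.is_atomic_gemm()
            else tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS)
```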
......@@ -3,7 +3,6 @@
# See LICENSE for license information.
"""Linear API"""
import warnings
from typing import Union, Optional, Callable, Tuple, List, Dict, Any
import torch
......@@ -79,10 +78,8 @@ class _Linear(torch.autograd.Function):
parallel_mode: Union[str, None],
is_grad_enabled: bool,
primary_weights_in_fp8: bool,
ub_split_rs: bool,
ub_split_ag: bool,
ub_atomic_gemm_rs: bool,
ub_atomic_gemm_ag: bool,
ub_overlap_rs: bool,
ub_overlap_ag: bool,
ub_name: str
) -> torch.Tensor:
# Make sure input dimensions are compatible
......@@ -94,14 +91,8 @@ class _Linear(torch.autograd.Function):
assert_dim_for_fp8_exec(weight)
update_fp8_weights = is_first_microbatch is None or is_first_microbatch
if ub_split_rs or ub_atomic_gemm_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
ub_atomic_gemm_rs = False
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
assert fp8, "AtomicGemm overlap supported only for FP8 GEMM."
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs
# Cast input to expected dtype
inputmat = cast_if_needed(inputmat, activation_dtype)
......@@ -180,14 +171,23 @@ class _Linear(torch.autograd.Function):
proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
None, None, None, activation_dtype)
if ub_split_rs or ub_atomic_gemm_rs:
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name+"_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
if ub_obj_projout.is_atomic_gemm():
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
if ub_obj_projout.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
if ub_obj_projout.is_fp8_ubuf():
proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
meta_tensor = fp8_meta["scaling_fwd"]
......@@ -199,8 +199,6 @@ class _Linear(torch.autograd.Function):
dim_size[1] = weight.size(0)
out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo
_ = fp8_gemm(
weight_fp8._data,
fp8_meta["scaling_fwd"].scale_inv,
......@@ -216,9 +214,9 @@ class _Linear(torch.autograd.Function):
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
out=out,
ub_algo=ub_algo,
ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None,
extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None,
ub_algo=ub_algo if ub_overlap_rs else None,
ub=ub_obj_projout if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
out_index=proj_out_index,
fp8_meta_tensor = meta_tensor,
D_dtype = proj_out_tetype,
......@@ -238,13 +236,17 @@ class _Linear(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.max(-amin, amax).float()
if ub_split_rs:
ub_obj_projout = get_ub("proj_fprop")
if ub_overlap_rs:
ub_obj_projout = get_ub(ub_name+"_fprop")
out = ub_obj_projout.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
dim_size[1] = weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj_projout.is_p2p_overlap():
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
else:
dim_size = list(inputmat_total.size())
dim_size[1] = weight.size(0)
......@@ -258,9 +260,9 @@ class _Linear(torch.autograd.Function):
bias=bias,
use_bias=use_bias,
out=out,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
ub=ub_obj_projout if ub_split_rs else None,
extra_output_tensor=rs_out if ub_split_rs else None,
ub_algo=ub_algo if ub_overlap_rs else None,
ub=ub_obj_projout if ub_overlap_rs else None,
extra_output_tensor=rs_out if ub_overlap_rs else None,
)
if is_grad_enabled:
......@@ -307,14 +309,13 @@ class _Linear(torch.autograd.Function):
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_split_ag = ub_split_ag
ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag
ctx.ub_overlap_ag = ub_overlap_ag
ctx.ub_name = ub_name
ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
# Row Parallel Linear
if ub_split_rs or ub_atomic_gemm_rs:
if ub_overlap_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
......@@ -350,16 +351,16 @@ class _Linear(torch.autograd.Function):
weight_t_fp8 = weight.transpose(
update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy",
)
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
if tp_world_size == 1:
ctx.ub_split_ag = False
ctx.ub_atomic_gemm_ag = False
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
tp_world_size = get_distributed_world_size(ctx.tp_group)
ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
if ctx.ub_overlap_ag:
dim_size = list(grad_output.size())
dim_size[0] = dim_size[0] * tp_world_size
ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad")
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
(
grad_output,
grad_output_c,
......@@ -397,8 +398,6 @@ class _Linear(torch.autograd.Function):
ctx.fp8_meta["recipe"], fprop_tensor=False
)
ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None
ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo
if ctx.requires_dgrad:
if ctx.fp8:
dgrad, _ = fp8_gemm(
......@@ -413,8 +412,8 @@ class _Linear(torch.autograd.Function):
ctx.activation_dtype,
get_workspace(),
use_split_accumulator=_2X_ACC_DGRAD,
ub_algo=ub_algo,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None,
ub_algo=ub_algo if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
)
else:
dgrad, _, _ = gemm(
......@@ -424,8 +423,9 @@ class _Linear(torch.autograd.Function):
get_workspace(),
layout="NN",
grad=True,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \
if ctx.ub_overlap_ag else None,
ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
)
# Overlap dgrad-RS/AR with wgrad
......@@ -442,7 +442,7 @@ class _Linear(torch.autograd.Function):
if ctx.fp8:
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag:
if ctx.ub_overlap_ag:
grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
if inputmat_t_total is None:
inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward)
......@@ -542,8 +542,6 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
None,
)
......@@ -629,10 +627,8 @@ class Linear(TransformerEngineBaseModule):
parallel_mode: Optional[str] = None,
parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
device: Union[torch.device, str] = "cuda",
ub_split_rs: bool = False,
ub_split_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_atomic_gemm_ag: bool = False,
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
......@@ -645,28 +641,18 @@ class Linear(TransformerEngineBaseModule):
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
self.ub_split_rs = ub_split_rs
self.ub_split_ag = ub_split_ag
self.ub_atomic_gemm_rs = ub_atomic_gemm_rs
self.ub_atomic_gemm_ag = ub_atomic_gemm_ag
if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]):
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
if ub_overlap_rs or ub_overlap_ag:
assert ub_name is not None, "Userbuffer name [string] is not set."
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
self.ub_name = ub_name
self.get_rng_state_tracker = get_rng_state_tracker
if device == 'meta':
assert parameters_split is None, ("Cannot split module parameters "
"on 'meta' device.")
if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs:
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......@@ -930,10 +916,8 @@ class Linear(TransformerEngineBaseModule):
self.parallel_mode,
torch.is_grad_enabled(),
self.primary_weights_in_fp8,
self.ub_split_rs,
self.ub_split_ag,
self.ub_atomic_gemm_rs,
self.ub_atomic_gemm_ag,
self.ub_overlap_rs,
self.ub_overlap_ag,
self.ub_name,
)
out = linear_fn(*args)
......
......@@ -259,10 +259,8 @@ class TransformerLayer(torch.nn.Module):
ub_tp_comm_overlap: bool = False,
ub_bulk_wgrad: bool = True,
ub_bulk_dgrad: bool = True,
ub_split_ag: bool = True,
ub_split_rs: bool = True,
ub_atomic_gemm_ag: bool = False,
ub_atomic_gemm_rs: bool = False,
ub_overlap_ag: bool = True,
ub_overlap_rs: bool = True,
bias: bool = True,
activation: str = 'gelu',
normalization: str = "LayerNorm",
......@@ -282,21 +280,8 @@ class TransformerLayer(torch.nn.Module):
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
ub_split_ag = ub_tp_comm_overlap and ub_split_ag
ub_split_rs = ub_tp_comm_overlap and ub_split_rs
ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs
assert (
not (ub_split_rs and ub_atomic_gemm_rs)
), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled."
ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag
assert (
not (ub_split_ag and ub_atomic_gemm_ag)
), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled."
if ub_atomic_gemm_rs or ub_atomic_gemm_ag:
warnings.warn(
"Atomic gemm uses a beta API from cublas and is not tested for all use cases."
)
ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag
ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs
bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
self.layer_number = layer_number
......@@ -370,10 +355,8 @@ class TransformerLayer(torch.nn.Module):
"qkv_weight_interleaved" : qkv_weight_interleaved,
"ub_bulk_wgrad" : ub_bulk_wgrad,
"ub_bulk_dgrad" : ub_bulk_dgrad,
"ub_split_ag" : ub_split_ag,
"ub_split_rs" : ub_split_rs,
"ub_atomic_gemm_rs" : ub_atomic_gemm_rs,
"ub_atomic_gemm_ag" : ub_atomic_gemm_ag,
"ub_overlap_ag" : ub_overlap_ag,
"ub_overlap_rs" : ub_overlap_rs,
"qkv_format" : self.attn_input_format,
}
......@@ -427,10 +410,8 @@ class TransformerLayer(torch.nn.Module):
zero_centered_gamma=zero_centered_gamma,
ub_bulk_wgrad=ub_bulk_wgrad,
ub_bulk_dgrad=ub_bulk_dgrad,
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
ub_atomic_gemm_rs=ub_atomic_gemm_rs,
ub_atomic_gemm_ag=ub_atomic_gemm_ag,
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
activation=activation,
normalization=normalization,
device=device,
......