Unverified Commit 8c0a0c93 authored by vasunvidia, committed by GitHub

DGRAD_RS UB overlap Bug fixes (#1004)



* DGRAD_RS UB overlap Bug fixes
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e39674b9
@@ -37,6 +37,7 @@
} while (0)
using namespace torch::indexing;
namespace ubuf {
/*
@@ -324,47 +325,48 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter);
for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm);
}
} else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n,
m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk,
n, m, _num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm);
}
} else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk,
n, m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// if (i == _num_splits-1) {
// _ub_comm->sms = UB_MAX_SM;
// }
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
rs_output_ptr += m_chunk * rs_output.element_size();
});
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
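Editor's note on the hunk above: the per-chunk reduce-scatter loop is now wrapped in TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY, so the FP8 reduce-scatter kernels are dispatched on the actual GEMM data type (B_type) instead of being hard-coded to __nv_fp8_e4m3. The following is a minimal, host-only sketch of that dispatch pattern; the macro, enum, and helper names are illustrative stand-ins, not the TransformerEngine definitions.

```cpp
#include <cstdio>
#include <stdexcept>
#include <type_traits>

// Stand-in FP8 tag types; the real code uses __nv_fp8_e4m3 / __nv_fp8_e5m2.
struct fp8_e4m3 {};
struct fp8_e5m2 {};
enum class DType { kFloat8E4M3, kFloat8E5M2, kBFloat16 };

// Hypothetical stand-in for TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY: it expands the
// body with `type_alias` bound to the FP8 type selected at run time by `dtype`.
#define TYPE_SWITCH_FP8ONLY(dtype, type_alias, ...)                                \
  switch (dtype) {                                                                 \
    case DType::kFloat8E4M3: { using type_alias = fp8_e4m3; __VA_ARGS__; } break;  \
    case DType::kFloat8E5M2: { using type_alias = fp8_e5m2; __VA_ARGS__; } break;  \
    default: throw std::runtime_error("expected an FP8 type");                     \
  }

// Placeholder for a templated reduce-scatter launch such as
// reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(...).
template <typename T>
void reduce_scatter_chunk(int chunk) {
  std::printf("chunk %d dispatched as %s\n", chunk,
              std::is_same<T, fp8_e4m3>::value ? "e4m3" : "e5m2");
}

int main() {
  DType b_type = DType::kFloat8E5M2;  // before the fix, the call was pinned to e4m3
  TYPE_SWITCH_FP8ONLY(b_type, fp8_type,
    for (int i = 0; i < 2; i++) { reduce_scatter_chunk<fp8_type>(i); });
}
```

The point of the pattern is that the loop body is written once against the alias fp8_type, and the switch selects the concrete FP8 type at run time.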
@@ -422,111 +424,115 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
n, m, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
if (gemm_overlap) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::from_blob(
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);
}
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::from_blob(
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
});
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
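For orientation, the restructured non-overlap branch above keeps a single loop over _num_splits: chunk i's GEMM is issued on compute stream i % _stream_compute.size(), its reduce-scatter follows on the communication stream, the last chunk's communication is given UB_MAX_SM, and the input/output offsets now advance once per iteration. The sketch below models only that control flow; constants and names are hypothetical and no CUDA calls are made.

```cpp
#include <cstdio>

// Illustrative-only model of the non-overlap split-GEMM + reduce-scatter loop.
constexpr int kMaxSMs = 32;

int main() {
  const int num_splits = 4;
  const int num_compute_streams = 3;
  const int m_chunk = 128, elem_size = 2;
  long rs_output_off = 0;
  int comm_sms = 16;  // reduced SM count while compute and communication overlap

  for (int i = 0; i < num_splits; i++) {
    int stream_id = i % num_compute_streams;      // round-robin compute stream
    std::printf("GEMM chunk %d on compute stream %d\n", i, stream_id);
    if (i == num_splits - 1) comm_sms = kMaxSMs;  // last chunk: nothing left to overlap
    std::printf("reduce-scatter chunk %d at byte offset %ld using %d SMs\n",
                i, rs_output_off, comm_sms);
    rs_output_off += static_cast<long>(m_chunk) * elem_size;  // advance every chunk
  }
}
```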
@@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
}
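The P2P reduce-scatter epilogue above now also dispatches reduce_fp8_in_bf16_out on the FP8 type selected by B_type. Conceptually the kernel sums _tp_size FP8 output chunks element-wise and dequantizes with _ubuf_scale_inv; a host-only sketch of that reduction follows, with int8 codes standing in for FP8 storage (not the TE kernel itself).

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Conceptual sketch of reduce_fp8_in_bf16_out<fp8_type>: sum tp_size FP8 GEMM-output
// chunks element-wise, dequantize with the shared scale_inv, store in a wider type.
int main() {
  const int tp_size = 4, numel = 8;
  const float scale_inv = 0.5f;
  std::vector<std::vector<int8_t>> chunks(tp_size, std::vector<int8_t>(numel, 2));
  std::vector<float> rs_output(numel, 0.0f);  // bf16 in the real kernel

  for (int j = 0; j < numel; j++) {
    float acc = 0.0f;
    for (int t = 0; t < tp_size; t++) acc += static_cast<float>(chunks[t][j]) * scale_inv;
    rs_output[j] = acc;
  }
  std::printf("rs_output[0] = %.1f\n", rs_output[0]);  // 4 ranks * 2 * 0.5 = 4.0
}
```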
/*
@@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......
@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd_system(flagptr, 1);
......
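The instantiation hunk above adds __nv_fp8_e5m2 specializations of the strided atomic and multi-atomic reduce-scatter kernels alongside the existing __nv_fp8_e4m3 ones. A toy example of why the explicit instantiations are needed is given below; the names and body are placeholders, not the real userbuffers kernels.

```cpp
#include <cstdio>

// Minimal illustration of the explicit-instantiation pattern: when a kernel template
// is defined in one translation unit, every FP8 specialization that other files link
// against must be instantiated explicitly.
struct fp8_e4m3 {};
struct fp8_e5m2 {};

template <typename fp8type>
void reducescatter_fp8_stub(int numchunks) {
  std::printf("%zu-byte tag type, %d chunks\n", sizeof(fp8type), numchunks);
}

// Without the e5m2 lines (the analogue of what this hunk adds for the real kernels),
// code paths selecting e5m2 via the type switch would fail at link time.
template void reducescatter_fp8_stub<fp8_e4m3>(int numchunks);
template void reducescatter_fp8_stub<fp8_e5m2>(int numchunks);

int main() { reducescatter_fp8_stub<fp8_e5m2>(4); }
```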
@@ -233,7 +233,9 @@ def initialize_ub(
wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs
layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
if ub_cfgs is not None and name in ub_cfgs:
......
@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward,
out=ln_out_fp8,
)
ln_out = ln_out_fp8
ln_out = torch.empty_like(ln_out_fp8)
else:
ln_out_total = tex.cast_to_fp8(
ln_out_total,
......
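The last Python hunk replaces `ln_out = ln_out_fp8` with a fresh `torch.empty_like` allocation, so `ln_out` no longer aliases the FP8 cast buffer. As a generic illustration of the hazard of such aliasing (not a claim about the exact failure mode in this module), consider the following C++ sketch.

```cpp
#include <cstdio>
#include <vector>

// Purely illustrative: when one name merely aliases another buffer, a later in-place
// write through either name changes both. Allocating separate storage (the analogue
// of torch.empty_like in the fix) keeps the two buffers independent.
int main() {
  std::vector<float> ln_out_fp8(4, 1.0f);

  std::vector<float>& aliased_ln_out = ln_out_fp8;        // old behavior: same storage
  std::vector<float> separate_ln_out(ln_out_fp8.size());  // new behavior: fresh storage

  ln_out_fp8[0] = 7.0f;  // subsequent in-place write to the FP8 buffer
  std::printf("aliased sees %.1f, separate still holds %.1f\n",
              aliased_ln_out[0], separate_ln_out[0]);
}
```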