Unverified Commit 8c0a0c93 authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

DGRAD_RS UB overlap Bug fixes (#1004)



* DGRAD_RS UB overlap Bug fixes
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e39674b9
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
} while (0) } while (0)
using namespace torch::indexing; using namespace torch::indexing;
namespace ubuf { namespace ubuf {
/* /*
...@@ -324,47 +325,48 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -324,47 +325,48 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter); counter);
for (int i = 0; i < _num_splits; i++) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); B_type, fp8_type, for (int i = 0; i < _num_splits; i++) {
if (env_p != nullptr && env_p[0] == '1') { const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (i == _num_splits - 1) { if (env_p != nullptr && env_p[0] == '1') {
_ub_comm->sms = UB_MAX_SM; if (i == _num_splits - 1) {
} _ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) { }
assert(_ubuf_scale_inv_initialized); if (_ubuf.element_size() == 1) {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); assert(_ubuf_scale_inv_initialized);
reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
&counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
} else { _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, } else {
_num_splits, &counter_ptr[i], _ub_comm, reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk,
(cudaStream_t)_stream_comm); n, m, _num_splits, &counter_ptr[i], _ub_comm,
} (cudaStream_t)_stream_comm);
} else if (env_p != nullptr && env_p[0] == '2') { }
if (_ubuf.element_size() == 1) { } else if (env_p != nullptr && env_p[0] == '2') {
assert(_ubuf_scale_inv_initialized); if (_ubuf.element_size() == 1) {
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); assert(_ubuf_scale_inv_initialized);
reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
} else { counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, } else {
m, _num_splits, counter_ptr, _ub_comm, reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk,
(cudaStream_t)_stream_comm); n, m, _num_splits, counter_ptr, _ub_comm,
} (cudaStream_t)_stream_comm);
break; }
} else { break;
consumer(counter_ptr, i, (cudaStream_t)_stream_comm); } else {
// if (i == _num_splits-1) { consumer(counter_ptr, i, (cudaStream_t)_stream_comm);
// _ub_comm->sms = UB_MAX_SM; // if (i == _num_splits-1) {
// } // _ub_comm->sms = UB_MAX_SM;
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, // }
_ub_comm, (cudaStream_t)_stream_comm); reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
} _ub_comm, (cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size(); rs_output_ptr += m_chunk * rs_output.element_size();
} });
_ub_comm->sms = ori_sms; _ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
...@@ -422,111 +424,115 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -422,111 +424,115 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); B_type, fp8_type,
torch::Tensor output_chunk = if (gemm_overlap) {
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); torch::Tensor input_a_chunk =
torch::Tensor workspace_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); torch::Tensor output_chunk =
at::cuda::setCurrentCUDAStream(_stream_compute[0]); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, torch::Tensor workspace_chunk =
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
for (int i = 1; i < _num_splits; i++) { transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
input_a_chunk_ptr += input_a_chunk_size * B.element_size(); grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); _math_sms);
torch::Tensor input_a_chunk = for (int i = 1; i < _num_splits; i++) {
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); input_a_chunk_ptr += input_a_chunk_size * B.element_size();
torch::Tensor output_chunk = output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::Tensor input_a_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
{workspace_size_chunk}, workspace.options()); torch::Tensor output_chunk =
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, torch::Tensor workspace_chunk = torch::from_blob(
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, {workspace_size_chunk}, workspace.options());
_math_sms); at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
CHECK_CUDA(cudaEventRecord( transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()])); grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); _math_sms);
// Communication chunk CHECK_CUDA(cudaEventRecord(
if (_ubuf.element_size() == 1) { _start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
assert(_ubuf_scale_inv_initialized); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( // Communication chunk
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, if (_ubuf.element_size() == 1) {
_ub_comm, (cudaStream_t)_stream_comm); assert(_ubuf_scale_inv_initialized);
} else { float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
m_chunk, n, m, _ub_comm, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
(cudaStream_t)_stream_comm); m, _ub_comm, (cudaStream_t)_stream_comm);
} } else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
rs_output_ptr += m_chunk * rs_output.element_size(); (i - 1) * output_chunk_size, m_chunk, n, m,
} _ub_comm, (cudaStream_t)_stream_comm);
int last_compute_stream_id = }
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA( rs_output_ptr += m_chunk * rs_output.element_size();
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id])); }
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0)); int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
// Last communication chunk with max SM CHECK_CUDA(
_ub_comm->sms = UB_MAX_SM; cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
if (_ubuf.element_size() == 1) { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
n, m, _ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk // Last communication chunk with max SM
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM; _ub_comm->sms = UB_MAX_SM;
} if (_ubuf.element_size() == 1) {
if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized);
assert(_ubuf_scale_inv_initialized); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
_ub_comm, (cudaStream_t)_stream_comm); } else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);
}
} else { } else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, for (int i = 0; i < _num_splits; i++) {
m_chunk, n, m, _ub_comm, torch::Tensor input_a_chunk =
(cudaStream_t)_stream_comm); torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
} torch::Tensor output_chunk =
rs_output_ptr += m_chunk * rs_output.element_size(); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
input_a_chunk_ptr += input_a_chunk_size * B.element_size(); torch::Tensor workspace_chunk = torch::from_blob(
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
} {workspace_size_chunk}, workspace.options());
} at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
});
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
...@@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks // Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { B_type, fp8_type,
assert(_ubuf_scale_inv_initialized); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); assert(_ubuf_scale_inv_initialized);
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
} else { reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
torch::Tensor reduce_buf = torch::from_blob( _ubufs[0].numel(), (cudaStream_t)stream_main);
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); } else {
torch::sum_out(rs_output, reduce_buf, 0); torch::Tensor reduce_buf = torch::from_blob(
} reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
} }
/* /*
...@@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks // Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { B_type, fp8_type,
assert(_ubuf_scale_inv_initialized); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); assert(_ubuf_scale_inv_initialized);
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
} else { reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
torch::Tensor reduce_buf = torch::from_blob( _ubufs[0].numel(), (cudaStream_t)stream_main);
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); } else {
torch::sum_out(rs_output, reduce_buf, 0); torch::Tensor reduce_buf = torch::from_blob(
} reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......
...@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( ...@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd_system(flagptr, 1); atomicAdd_system(flagptr, 1);
......
...@@ -233,7 +233,9 @@ def initialize_ub( ...@@ -233,7 +233,9 @@ def initialize_ub(
wgrad_name = name.replace("dgrad", "wgrad") wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs assert wgrad_name not in ub_cfgs
layers_reduce_scatter_overlap.remove(wgrad_name) layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name) layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
if ub_cfgs is not None and name in ub_cfgs: if ub_cfgs is not None and name in ub_cfgs:
......
...@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
out=ln_out_fp8, out=ln_out_fp8,
) )
ln_out = ln_out_fp8 ln_out = torch.empty_like(ln_out_fp8)
else: else:
ln_out_total = tex.cast_to_fp8( ln_out_total = tex.cast_to_fp8(
ln_out_total, ln_out_total,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment