Unverified Commit 8c0a0c93 authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

DGRAD_RS UB overlap Bug fixes (#1004)



* DGRAD_RS UB overlap Bug fixes
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e39674b9
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
} while (0) } while (0)
using namespace torch::indexing; using namespace torch::indexing;
namespace ubuf { namespace ubuf {
/* /*
...@@ -324,7 +325,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -324,7 +325,8 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/, _math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter); counter);
for (int i = 0; i < _num_splits; i++) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC"); const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') { if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) { if (i == _num_splits - 1) {
...@@ -333,24 +335,24 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -333,24 +335,24 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
&counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm); _num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
} else { } else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk,
_num_splits, &counter_ptr[i], _ub_comm, n, m, _num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm); (cudaStream_t)_stream_comm);
} }
} else if (env_p != nullptr && env_p[0] == '2') { } else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm); counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
} else { } else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk,
m, _num_splits, counter_ptr, _ub_comm, n, m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm); (cudaStream_t)_stream_comm);
} }
break; break;
...@@ -364,7 +366,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -364,7 +366,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
} }
rs_output_ptr += m_chunk * rs_output.element_size(); rs_output_ptr += m_chunk * rs_output.element_size();
} });
_ub_comm->sms = ori_sms; _ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
...@@ -422,16 +424,20 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -422,16 +424,20 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(pre_gelu_out.numel() == 0); assert(pre_gelu_out.numel() == 0);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
if (gemm_overlap) { if (gemm_overlap) {
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk = torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]); at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms); grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
for (int i = 1; i < _num_splits; i++) { for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size();
...@@ -441,13 +447,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -441,13 +447,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk = torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk = torch::from_blob(
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options()); {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms); _math_sms);
CHECK_CUDA(cudaEventRecord( CHECK_CUDA(cudaEventRecord(
...@@ -458,13 +464,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -458,13 +464,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
_ub_comm, (cudaStream_t)_stream_comm); m, _ub_comm, (cudaStream_t)_stream_comm);
} else { } else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
m_chunk, n, m, _ub_comm, (i - 1) * output_chunk_size, m_chunk, n, m,
(cudaStream_t)_stream_comm); _ub_comm, (cudaStream_t)_stream_comm);
} }
rs_output_ptr += m_chunk * rs_output.element_size(); rs_output_ptr += m_chunk * rs_output.element_size();
...@@ -480,13 +486,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -480,13 +486,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
n, m, _ub_comm, (cudaStream_t)_stream_comm); m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
} else { } else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m, (_num_splits - 1) * output_chunk_size, m_chunk, n,
_ub_comm, (cudaStream_t)_stream_comm); m, _ub_comm, (cudaStream_t)_stream_comm);
} }
} else { } else {
for (int i = 0; i < _num_splits; i++) { for (int i = 0; i < _num_splits; i++) {
...@@ -494,13 +500,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -494,13 +500,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options()); torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk = torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options()); torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::Tensor workspace_chunk = torch::from_blob(
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options()); {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms); _math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm, CHECK_CUDA(cudaEventRecord(_start_comm,
...@@ -514,7 +520,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -514,7 +520,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) { if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m, rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm); _ub_comm, (cudaStream_t)_stream_comm);
} else { } else {
...@@ -526,7 +532,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -526,7 +532,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
input_a_chunk_ptr += input_a_chunk_size * B.element_size(); input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
} }
} });
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
...@@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1051,18 +1057,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks // Reduce GEMM output chunks
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); _ubufs[0].numel(), (cudaStream_t)stream_main);
} else { } else {
torch::Tensor reduce_buf = torch::from_blob( torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} });
} }
/* /*
...@@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { ...@@ -1145,18 +1153,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks // Reduce GEMM output chunks
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr()); char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized); assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr()); float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr()); char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_tp_size, _ubufs[0].numel(), (cudaStream_t)stream_main); _ubufs[0].numel(), (cudaStream_t)stream_main);
} else { } else {
torch::Tensor reduce_buf = torch::from_blob( torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0); torch::sum_out(rs_output, reduce_buf, 0);
} });
for (size_t i = 0; i < _stream_compute.size(); i++) { for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i])); CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......
...@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( ...@@ -1890,11 +1890,18 @@ template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>( template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e4m3>(
void *output, float *scale, const int handler, const int offset, const int rowelements, void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in, const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream); const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
void *output, float *scale, const int handler, const int offset, const int rowelements,
const int colelements, const int strideelements_out, const int strideelements_in,
const int numchunks, void *counters, communicator *comm, cudaStream_t stream);
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) { __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd_system(flagptr, 1); atomicAdd_system(flagptr, 1);
......
...@@ -233,7 +233,9 @@ def initialize_ub( ...@@ -233,7 +233,9 @@ def initialize_ub(
wgrad_name = name.replace("dgrad", "wgrad") wgrad_name = name.replace("dgrad", "wgrad")
assert wgrad_name not in ub_cfgs assert wgrad_name not in ub_cfgs
layers_reduce_scatter_overlap.remove(wgrad_name) layers_reduce_scatter_overlap.remove(wgrad_name)
layers_all_gather_overlap.remove(name)
layers_reduce_scatter_overlap.append(name) layers_reduce_scatter_overlap.append(name)
methods["pipeline"].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
if ub_cfgs is not None and name in ub_cfgs: if ub_cfgs is not None and name in ub_cfgs:
......
...@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -184,7 +184,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_dtype_forward, fp8_dtype_forward,
out=ln_out_fp8, out=ln_out_fp8,
) )
ln_out = ln_out_fp8 ln_out = torch.empty_like(ln_out_fp8)
else: else:
ln_out_total = tex.cast_to_fp8( ln_out_total = tex.cast_to_fp8(
ln_out_total, ln_out_total,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment