Unverified commit a6db82d9 authored by Alp Dener, committed by GitHub

[C/PyTorch] Fixing incorrect use of TYPE_SWITCH_FP8_ONLY in GEMM + reduce-scatter overlap (#1023)



* FP8 type switch macro now wraps only the FP8 kernel to avoid invalid type errors; a simplified sketch of this pattern appears below the commit metadata.
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c57a81f0
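
For context on the fix, here is a minimal, self-contained C++ sketch of the dispatch pattern involved. The macro, enum, and kernel names below are simplified stand-ins invented for illustration, not the actual Transformer Engine definitions: an FP8-only type switch maps a runtime dtype to a compile-time type and errors out for any non-FP8 dtype, so it must wrap only the FP8 kernel launch, not the surrounding split loop that also contains non-FP8 code paths.

// Sketch only: simplified stand-ins for the FP8-only type-switch pattern.
#include <cstdio>
#include <stdexcept>

enum class DType { kFloat8E4M3, kFloat8E5M2, kBFloat16 };
struct fp8e4m3 {};
struct fp8e5m2 {};

// Maps a runtime FP8 dtype to a compile-time type; any other dtype is an error.
#define TYPE_SWITCH_FP8ONLY(dtype, type_alias, ...)                              \
  switch (dtype) {                                                               \
    case DType::kFloat8E4M3: { using type_alias = fp8e4m3; __VA_ARGS__; } break; \
    case DType::kFloat8E5M2: { using type_alias = fp8e5m2; __VA_ARGS__; } break; \
    default: throw std::runtime_error("invalid FP8 type");                       \
  }

template <typename T> void reducescatter_fp8() { std::printf("FP8 reduce-scatter chunk\n"); }
void reducescatter_bf16() { std::printf("BF16 reduce-scatter chunk\n"); }

void reduce_scatter_chunks(DType b_type, bool ubuf_is_fp8, int num_splits) {
  // Before the fix, the whole loop below sat inside the FP8-only switch, so the
  // non-FP8 branch was only reachable through that switch and a non-FP8 b_type
  // hit the "invalid FP8 type" error. After the fix, the loop is unconditional
  // and only the FP8 kernel launch is wrapped by the switch.
  for (int i = 0; i < num_splits; i++) {
    if (ubuf_is_fp8) {
      TYPE_SWITCH_FP8ONLY(b_type, fp8_t, reducescatter_fp8<fp8_t>());
    } else {
      reducescatter_bf16();  // no longer guarded by the FP8-only switch
    }
  }
}

int main() {
  reduce_scatter_chunks(DType::kBFloat16, /*ubuf_is_fp8=*/false, 2);   // BF16 path now works
  reduce_scatter_chunks(DType::kFloat8E4M3, /*ubuf_is_fp8=*/true, 2);  // FP8 path dispatches via the switch
  return 0;
}
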
@@ -19,6 +19,7 @@
#include <torch/extension.h>
#include <torch/types.h>
#include "common/common.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "extensions.h"
@@ -325,8 +326,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, _num_splits /*m_split*/, 0 /*n_split*/, true /*gemm_producer*/,
counter);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type, for (int i = 0; i < _num_splits; i++) {
for (int i = 0; i < _num_splits; i++) {
const char *env_p = std::getenv("NVTE_RS_STRIDED_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
if (i == _num_splits - 1) {
@@ -335,24 +335,28 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, m,
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm);
_num_splits, &counter_ptr[i], _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk,
n, m, _num_splits, &counter_ptr[i], _ub_comm,
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
(cudaStream_t)_stream_comm);
}
} else if (env_p != nullptr && env_p[0] == '2') {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm);
counter_ptr, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk,
n, m, _num_splits, counter_ptr, _ub_comm,
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n,
m, _num_splits, counter_ptr, _ub_comm,
(cudaStream_t)_stream_comm);
}
break;
@@ -366,7 +370,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
}
rs_output_ptr += m_chunk * rs_output.element_size();
});
}
_ub_comm->sms = ori_sms;
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
@@ -424,20 +428,16 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
assert(pre_gelu_out.numel() == 0);
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
if (gemm_overlap) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
@@ -447,13 +447,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::from_blob(
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(
@@ -464,13 +464,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);
m, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm,
(cudaStream_t)_stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
@@ -486,13 +488,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n,
m, _ub_comm, (cudaStream_t)_stream_comm);
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
@@ -500,13 +504,13 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk = torch::from_blob(
workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out,
grad, workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
@@ -520,9 +524,11 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, d_scale_inv_ptr, _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
_ub_comm, (cudaStream_t)_stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm,
@@ -532,7 +538,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
});
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
@@ -1057,20 +1063,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
}
}
/*
@@ -1153,20 +1159,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
assert(_ubuf_scale_inv_initialized);
float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
B_type, fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, _tp_size,
_ubufs[0].numel(), (cudaStream_t)stream_main);
_ubufs[0].numel(), (cudaStream_t)stream_main););
} else {
torch::Tensor reduce_buf = torch::from_blob(
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
});
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));