Unverified Commit 90f3c9ad authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Avoid select op in PyTorch extensions (#865)



* Avoid select operation in cast-transpose extension
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid select operation in cast-transpose-dbias extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid select op in LayerNorm and RMSNorm
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent bbb22395
......@@ -35,13 +35,16 @@ def layernorm_fwd_fp8(
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.scale,
ln_out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
sm_margin,
zero_centered_gamma
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
return tex.layernorm_fwd_fp8(
......@@ -49,12 +52,15 @@ def layernorm_fwd_fp8(
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
sm_margin,
zero_centered_gamma
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
......@@ -124,25 +130,31 @@ def rmsnorm_fwd_fp8(
inp,
weight,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.scale,
rmsnorm_out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
sm_margin,
zero_centered_gamma
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
return tex.rmsnorm_fwd_fp8(
inp,
weight,
eps,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
sm_margin,
zero_centered_gamma
zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
......
......@@ -43,12 +43,15 @@ def fp8_cast_transpose_fused(
tex.fused_cast_transpose_noop(
inp,
noop_flag,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
cast_out,
transpose_out,
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
if return_outputs:
......@@ -65,10 +68,13 @@ def fp8_cast_transpose_bgrad_fused(
"""Cast + Transpose + BGRAD with FP8 output"""
return tex.fused_cast_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
......@@ -82,11 +88,14 @@ def fp8_transpose_bgrad_fused(
"""Transpose + BGRAD with FP8 output"""
return tex.fused_fp8_transpose_bgrad(
inp,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
TE_DType[grad_bias_type],
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
......@@ -101,8 +110,11 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
return tex.fused_cast_transpose_bgrad_dgelu(
grad_output,
gelu_input,
fp8_meta_tensor.scale[fp8_tensor],
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
)
......@@ -138,10 +138,15 @@ at::Tensor allocateTorchTensor(int M,
at::CUDA(GetATenDType(dtype)));
}
void *getDataPtr(at::Tensor t) {
if (t.numel() > 0) {
return t.data_ptr();
} else {
return nullptr;
/*! \brief Get a pointer into a tensor's data buffer.
 *
 *  Returns the tensor's data pointer advanced by `offset` elements
 *  (scaled by the element size). An empty tensor yields nullptr, and
 *  no offset is applied to a null pointer.
 */
void* getDataPtr(at::Tensor tensor, int offset) {
  void* dptr = tensor.numel() > 0 ? tensor.data_ptr() : nullptr;
  if (dptr != nullptr && offset != 0) {
    // Byte-wise pointer arithmetic: offset is in elements, so scale
    // by the per-element size before advancing.
    dptr = static_cast<char*>(dptr) + offset * tensor.element_size();
  }
  return dptr;
}
......@@ -186,6 +186,6 @@ at::Tensor allocateTorchTensor(int M,
transformer_engine::DType dtype
);
void *getDataPtr(at::Tensor t);
void* getDataPtr(at::Tensor tensor, int offset = 0);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
......@@ -263,7 +263,10 @@ void fused_cast_transpose_noop(at::Tensor input,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
);
......@@ -271,7 +274,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
);
......@@ -280,7 +286,10 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
transformer_engine::DType grad_bias_type
transformer_engine::DType grad_bias_type,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
);
......@@ -289,7 +298,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
);
......@@ -429,7 +441,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
......@@ -442,7 +457,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
......@@ -454,7 +472,10 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
......@@ -503,7 +524,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
......@@ -515,7 +539,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
......@@ -526,7 +553,10 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
);
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input,
......
......@@ -74,14 +74,18 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps,
scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma);
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
}
......@@ -95,35 +99,49 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
using namespace transformer_engine;
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
getDataPtr(amax), getDataPtr(scale),
getDataPtr(scale_inv));
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(),
{N, H},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
// Query workspace sizes
transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
// Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
......@@ -136,7 +154,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
// Launch kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
......@@ -155,12 +173,19 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
input, weight, bias, eps,
scale, amax, scale_inv,
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
return out[0];
}
......@@ -273,14 +298,18 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return rmsnorm_fwd_fp8_noalloc(input, weight, eps,
scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma);
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
}
......@@ -293,32 +322,46 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
using namespace transformer_engine;
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
getDataPtr(amax), getDataPtr(scale),
getDataPtr(scale_inv));
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(),
{N, H},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
// Query workspace sizes
transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
// Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
......@@ -331,7 +374,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
// Launch kernel
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
......@@ -349,12 +392,18 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma);
input, weight, eps,
scale, amax, scale_inv,
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
return out[0];
}
......
......@@ -31,25 +31,126 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("layernorm_fwd_fp8",
&layernorm_fwd_fp8,
"LN FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("bias"),
py::arg("eps"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("layernorm_fwd_fp8_noalloc",
&layernorm_fwd_fp8_noalloc,
"LN FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("bias"),
py::arg("eps"),
py::arg("scale"),
py::arg("ln_out"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8");
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8");
m.def("rmsnorm_fwd_fp8",
&rmsnorm_fwd_fp8,
"RMSNorm FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("eps"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_fwd_fp8_noalloc",
&rmsnorm_fwd_fp8_noalloc,
"RMSNorm FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("eps"),
py::arg("scale"),
py::arg("ln_out"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD");
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU");
m.def("fused_cast_transpose_noop",
&fused_cast_transpose_noop,
"Fused Cast + Transpose with noop option",
py::arg("input"),
py::arg("noop"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("input_cast"),
py::arg("input_transpose"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad",
&fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD",
py::arg("grad_output"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_fp8_transpose_bgrad",
&fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD",
py::arg("grad_output"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("grad_bias_type"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu",
&fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU",
py::arg("grad_output"),
py::arg("gelu_input"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
......
......@@ -39,23 +39,42 @@ void fused_cast_transpose_noop(at::Tensor input,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) {
using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto input_cu = makeTransformerEngineTensor(input);
auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(),
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(),
{M, N},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(),
{N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
// Launch kernel
nvte_cast_transpose_with_noop(input_cu.data(),
noop_cu.data(),
output_cast_cu.data(),
output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
......@@ -65,47 +84,64 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) {
using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
// Allocate output tensors
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto grad_output_cast =
allocateTorchTensor(grad_output.size(0),
auto grad_output_cast = allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
auto grad_output_transpose = allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
if (M == 0 || N == 0)
// Return immediately if tensors are empty
if (M == 0 || N == 0) {
return {grad_bias, grad_output_cast, grad_output_transpose};
}
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(),
{M, N},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
{N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
// Launch kernel
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
......@@ -119,36 +155,51 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
transformer_engine::DType grad_bias_type
transformer_engine::DType grad_bias_type,
int scale_offset,
int amax_offset,
int scale_inv_offset
) {
using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
auto grad_output_transpose = allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(),
{M, N},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
{N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
// Launch kernel
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
......@@ -162,46 +213,59 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) {
using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto dgelu =
allocateTorchTensor(grad_output.size(0),
auto dgelu = allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto dgelu_transpose =
allocateTorchTensor(grad_output.size(1),
auto dgelu_transpose = allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(),
{M, N},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(),
{N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
// Launch kernel
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
......
......@@ -414,7 +414,10 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv,
otype_arg,
sm_margin,
zero_centered_gamma);
zero_centered_gamma,
fp8_tensor, // scale_offset
fp8_tensor, // amax_offset
fp8_tensor); // scale_inv_offset
return output;
}
......@@ -460,7 +463,10 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv,
otype_arg,
sm_margin,
zero_centered_gamma);
zero_centered_gamma,
fp8_tensor, // scale_offset
fp8_tensor, // amax_offset
fp8_tensor); // scale_inv_offset
return output;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment