Unverified commit 90f3c9ad, authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Avoid select op in PyTorch extensions (#865)



* Avoid select operation in cast-transpose extension
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid select operation in cast-transpose-dbias extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid select op in LayerNorm and RMSNorm
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter errors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent bbb22395
...@@ -35,13 +35,16 @@ def layernorm_fwd_fp8( ...@@ -35,13 +35,16 @@ def layernorm_fwd_fp8(
weight, weight,
bias, bias,
eps, eps,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
ln_out, ln_out,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
sm_margin, sm_margin,
zero_centered_gamma zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
return tex.layernorm_fwd_fp8( return tex.layernorm_fwd_fp8(
...@@ -49,12 +52,15 @@ def layernorm_fwd_fp8( ...@@ -49,12 +52,15 @@ def layernorm_fwd_fp8(
weight, weight,
bias, bias,
eps, eps,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
sm_margin, sm_margin,
zero_centered_gamma zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
...@@ -124,25 +130,31 @@ def rmsnorm_fwd_fp8( ...@@ -124,25 +130,31 @@ def rmsnorm_fwd_fp8(
inp, inp,
weight, weight,
eps, eps,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
rmsnorm_out, rmsnorm_out,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
sm_margin, sm_margin,
zero_centered_gamma zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
return tex.rmsnorm_fwd_fp8( return tex.rmsnorm_fwd_fp8(
inp, inp,
weight, weight,
eps, eps,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
sm_margin, sm_margin,
zero_centered_gamma zero_centered_gamma,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
......
...@@ -43,12 +43,15 @@ def fp8_cast_transpose_fused( ...@@ -43,12 +43,15 @@ def fp8_cast_transpose_fused(
tex.fused_cast_transpose_noop( tex.fused_cast_transpose_noop(
inp, inp,
noop_flag, noop_flag,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
cast_out, cast_out,
transpose_out, transpose_out,
otype, otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
if return_outputs: if return_outputs:
...@@ -65,10 +68,13 @@ def fp8_cast_transpose_bgrad_fused( ...@@ -65,10 +68,13 @@ def fp8_cast_transpose_bgrad_fused(
"""Cast + Transpose + BGRAD with FP8 output""" """Cast + Transpose + BGRAD with FP8 output"""
return tex.fused_cast_transpose_bgrad( return tex.fused_cast_transpose_bgrad(
inp, inp,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
...@@ -82,11 +88,14 @@ def fp8_transpose_bgrad_fused( ...@@ -82,11 +88,14 @@ def fp8_transpose_bgrad_fused(
"""Transpose + BGRAD with FP8 output""" """Transpose + BGRAD with FP8 output"""
return tex.fused_fp8_transpose_bgrad( return tex.fused_fp8_transpose_bgrad(
inp, inp,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
TE_DType[grad_bias_type], TE_DType[grad_bias_type],
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
...@@ -101,8 +110,11 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ...@@ -101,8 +110,11 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
return tex.fused_cast_transpose_bgrad_dgelu( return tex.fused_cast_transpose_bgrad_dgelu(
grad_output, grad_output,
gelu_input, gelu_input,
fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv[fp8_tensor], fp8_meta_tensor.scale_inv,
otype, otype,
scale_offset=int(fp8_tensor),
amax_offset=int(fp8_tensor),
scale_inv_offset=int(fp8_tensor),
) )
...@@ -138,10 +138,15 @@ at::Tensor allocateTorchTensor(int M, ...@@ -138,10 +138,15 @@ at::Tensor allocateTorchTensor(int M,
at::CUDA(GetATenDType(dtype))); at::CUDA(GetATenDType(dtype)));
} }
void *getDataPtr(at::Tensor t) { void* getDataPtr(at::Tensor tensor, int offset) {
if (t.numel() > 0) { void* dptr = nullptr;
return t.data_ptr(); if (tensor.numel() > 0) {
} else { dptr = tensor.data_ptr();
return nullptr; }
if (dptr != nullptr && offset != 0) {
char* char_ptr = reinterpret_cast<char*>(dptr);
char_ptr += offset * tensor.element_size();
dptr = reinterpret_cast<void*>(char_ptr);
} }
return dptr;
} }
...@@ -186,6 +186,6 @@ at::Tensor allocateTorchTensor(int M, ...@@ -186,6 +186,6 @@ at::Tensor allocateTorchTensor(int M,
transformer_engine::DType dtype transformer_engine::DType dtype
); );
void *getDataPtr(at::Tensor t); void* getDataPtr(at::Tensor tensor, int offset = 0);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
...@@ -263,7 +263,10 @@ void fused_cast_transpose_noop(at::Tensor input, ...@@ -263,7 +263,10 @@ void fused_cast_transpose_noop(at::Tensor input,
at::Tensor scale_inv, at::Tensor scale_inv,
at::Tensor input_cast, at::Tensor input_cast,
at::Tensor input_transpose, at::Tensor input_transpose,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
); );
...@@ -271,7 +274,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, ...@@ -271,7 +274,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
); );
...@@ -280,7 +286,10 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output, ...@@ -280,7 +286,10 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
transformer_engine::DType grad_bias_type transformer_engine::DType grad_bias_type,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
); );
...@@ -289,7 +298,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, ...@@ -289,7 +298,10 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset = 0,
int amax_offset = 0,
int scale_inv_offset = 0
); );
...@@ -429,7 +441,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, ...@@ -429,7 +441,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input, std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
...@@ -442,7 +457,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -442,7 +457,10 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
...@@ -454,7 +472,10 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -454,7 +472,10 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
...@@ -503,7 +524,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, ...@@ -503,7 +524,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
...@@ -515,7 +539,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -515,7 +539,10 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
...@@ -526,7 +553,10 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -526,7 +553,10 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset = 0,
const int amax_offset = 0,
const int scale_inv_offset = 0
); );
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input,
......
...@@ -74,14 +74,18 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, ...@@ -74,14 +74,18 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, return layernorm_fwd_fp8_noalloc(input, weight, bias, eps,
scale, ln_out, amax, scale_inv, scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma); otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
} }
...@@ -95,35 +99,49 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -95,35 +99,49 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0)); size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1)); size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type()); // Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input); auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight); auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias); auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(),
getDataPtr(amax), getDataPtr(scale), {N, H},
getDataPtr(scale_inv)); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto mu_cu = makeTransformerEngineTensor(mu); auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma); auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config // Query workspace sizes
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data()); workspace.data(), barrier.data());
// Fill workspace and barrier // Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(), auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype()); workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), auto barrier_data = allocateSpace(barrier.shape(),
...@@ -136,7 +154,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -136,7 +154,7 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
barrier.shape(), barrier.shape(),
barrier.dtype()); barrier.dtype());
// Actual call to fwd kernel // Launch kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
...@@ -155,12 +173,19 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -155,12 +173,19 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
// This is a specialized version of layernorm_fwd_fp8, optimized for inference, // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = layernorm_fwd_fp8( std::vector<at::Tensor> out = layernorm_fwd_fp8(
input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma); input, weight, bias, eps,
scale, amax, scale_inv,
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
return out[0]; return out[0];
} }
...@@ -273,14 +298,18 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, ...@@ -273,14 +298,18 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype)));
return rmsnorm_fwd_fp8_noalloc(input, weight, eps, return rmsnorm_fwd_fp8_noalloc(input, weight, eps,
scale, ln_out, amax, scale_inv, scale, ln_out, amax, scale_inv,
otype, sm_margin, zero_centered_gamma); otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
} }
...@@ -293,32 +322,46 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -293,32 +322,46 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Choose kernel implementation
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd;
// Tensor dimensions
size_t N = static_cast<size_t>(input.size(0)); size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1)); size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type()); // Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType itype = GetTransformerEngineDType(input.scalar_type());
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input); auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight); auto gamma_cu = makeTransformerEngineTensor(weight);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(),
getDataPtr(amax), getDataPtr(scale), {N, H},
getDataPtr(scale_inv)); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto rsigma_cu = makeTransformerEngineTensor(rsigma); auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config // Query workspace sizes
const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; transformer_engine::TensorWrapper workspace, barrier;
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data()); workspace.data(), barrier.data());
// Fill workspace and barrier // Allocate workspaces
auto workspace_data = allocateSpace(workspace.shape(), auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype()); workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(), auto barrier_data = allocateSpace(barrier.shape(),
...@@ -331,7 +374,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, ...@@ -331,7 +374,7 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
barrier.shape(), barrier.shape(),
barrier.dtype()); barrier.dtype());
// Actual call to fwd kernel // Launch kernel
func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(),
rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
...@@ -349,12 +392,18 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, ...@@ -349,12 +392,18 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
const int sm_margin, const int sm_margin,
const bool zero_centered_gamma const bool zero_centered_gamma,
const int scale_offset,
const int amax_offset,
const int scale_inv_offset
) { ) {
// This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
// which only returns the normalized output. // which only returns the normalized output.
std::vector<at::Tensor> out = rmsnorm_fwd_fp8( std::vector<at::Tensor> out = rmsnorm_fwd_fp8(
input, weight, eps, scale, amax, scale_inv, otype, sm_margin, zero_centered_gamma); input, weight, eps,
scale, amax, scale_inv,
otype, sm_margin, zero_centered_gamma,
scale_offset, amax_offset, scale_inv_offset);
return out[0]; return out[0];
} }
......
...@@ -31,25 +31,126 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -31,25 +31,126 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD"); "Scaled Bottom-Right Corner Aligned Masked Softmax BWD");
// Other granular functions // Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); m.def("layernorm_fwd_fp8",
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8"); &layernorm_fwd_fp8,
"LN FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("bias"),
py::arg("eps"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("layernorm_fwd_fp8_noalloc",
&layernorm_fwd_fp8_noalloc,
"LN FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("bias"),
py::arg("eps"),
py::arg("scale"),
py::arg("ln_out"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD"); m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8"); m.def("rmsnorm_fwd_fp8",
m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8"); &rmsnorm_fwd_fp8,
"RMSNorm FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("eps"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_fwd_fp8_noalloc",
&rmsnorm_fwd_fp8_noalloc,
"RMSNorm FWD FP8",
py::arg("input"),
py::arg("weight"),
py::arg("eps"),
py::arg("scale"),
py::arg("ln_out"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("sm_margin"),
py::arg("zero_centered_gamma"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD"); m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD");
m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD");
m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, m.def("fused_cast_transpose_noop",
"Fused Cast + Transpose with noop option"); &fused_cast_transpose_noop,
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose with noop option",
"Fused Cast + Transpose + BGRAD"); py::arg("input"),
m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, py::arg("noop"),
"Fused FP8 Transpose + BGRAD"); py::arg("scale"),
m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, py::arg("amax"),
"Fused Cast + Transpose + BGRAD + DGELU"); py::arg("scale_inv"),
py::arg("input_cast"),
py::arg("input_transpose"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad",
&fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD",
py::arg("grad_output"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_fp8_transpose_bgrad",
&fused_fp8_transpose_bgrad,
"Fused FP8 Transpose + BGRAD",
py::arg("grad_output"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("grad_bias_type"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_cast_transpose_bgrad_dgelu",
&fused_cast_transpose_bgrad_dgelu,
"Fused Cast + Transpose + BGRAD + DGELU",
py::arg("grad_output"),
py::arg("gelu_input"),
py::arg("scale"),
py::arg("amax"),
py::arg("scale_inv"),
py::arg("otype"),
py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0,
py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose"); "Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
......
...@@ -39,23 +39,42 @@ void fused_cast_transpose_noop(at::Tensor input, ...@@ -39,23 +39,42 @@ void fused_cast_transpose_noop(at::Tensor input,
at::Tensor scale_inv, at::Tensor scale_inv,
at::Tensor input_cast, at::Tensor input_cast,
at::Tensor input_transpose, at::Tensor input_transpose,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(input.size(0)); size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1)); size_t N = static_cast<size_t>(input.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto input_cu = makeTransformerEngineTensor(input); auto input_cu = makeTransformerEngineTensor(input);
auto noop_cu = makeTransformerEngineTensor(noop); auto noop_cu = makeTransformerEngineTensor(noop);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(),
amax.data_ptr(), scale.data_ptr(), {M, N},
scale_inv.data_ptr()); otype,
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, amax_dptr,
amax.data_ptr(), scale.data_ptr(), scale_dptr,
scale_inv.data_ptr()); scale_inv_dptr);
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(),
nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), {N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
// Launch kernel
nvte_cast_transpose_with_noop(input_cu.data(),
noop_cu.data(),
output_cast_cu.data(),
output_transpose_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
} }
...@@ -65,47 +84,64 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, ...@@ -65,47 +84,64 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0)); size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1)); size_t N = static_cast<size_t>(grad_output.size(1));
// Allocate output tensors
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto grad_output_cast = auto grad_output_cast = allocateTorchTensor(grad_output.size(0),
allocateTorchTensor(grad_output.size(0),
grad_output.size(1), grad_output.size(1),
DType::kByte); DType::kByte);
auto grad_output_transpose = auto grad_output_transpose = allocateTorchTensor(grad_output.size(1),
allocateTorchTensor(grad_output.size(1),
grad_output.size(0), grad_output.size(0),
DType::kByte); DType::kByte);
if (M == 0 || N == 0) // Return immediately if tensors are empty
if (M == 0 || N == 0) {
return {grad_bias, grad_output_cast, grad_output_transpose}; return {grad_bias, grad_output_cast, grad_output_transpose};
}
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto input_cu = makeTransformerEngineTensor(grad_output); auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(),
otype, amax.data_ptr(), scale.data_ptr(), {M, N},
scale_inv.data_ptr()); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(), {N, M},
scale.data_ptr(), scale_inv.data_ptr()); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias); auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(), workspace.shape(),
workspace.dtype()); workspace.dtype());
// Launch kernel
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); workspace.data(), at::cuda::getCurrentCUDAStream());
...@@ -119,36 +155,51 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output, ...@@ -119,36 +155,51 @@ std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype, transformer_engine::DType otype,
transformer_engine::DType grad_bias_type transformer_engine::DType grad_bias_type,
int scale_offset,
int amax_offset,
int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0)); size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1)); size_t N = static_cast<size_t>(grad_output.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);
auto grad_output_transpose = auto grad_output_transpose = allocateTorchTensor(grad_output.size(1),
allocateTorchTensor(grad_output.size(1),
grad_output.size(0), grad_output.size(0),
DType::kByte); DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(),
otype, amax.data_ptr(), scale.data_ptr(), {M, N},
scale_inv.data_ptr()); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(), auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(), {N, M},
scale.data_ptr(), scale_inv.data_ptr()); otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias); auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); workspace.data(), at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(), workspace.shape(),
workspace.dtype()); workspace.dtype());
// Launch kernel
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream()); workspace.data(), at::cuda::getCurrentCUDAStream());
...@@ -162,46 +213,59 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, ...@@ -162,46 +213,59 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor scale, at::Tensor scale,
at::Tensor amax, at::Tensor amax,
at::Tensor scale_inv, at::Tensor scale_inv,
transformer_engine::DType otype transformer_engine::DType otype,
int scale_offset,
int amax_offset,
int scale_inv_offset
) { ) {
using namespace transformer_engine; using namespace transformer_engine;
// Tensor dimensions
size_t M = static_cast<size_t>(grad_output.size(0)); size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1)); size_t N = static_cast<size_t>(grad_output.size(1));
// Get pointers for FP8 scale, amax, scale-inverse
void* scale_dptr = getDataPtr(scale, scale_offset);
void* amax_dptr = getDataPtr(amax, amax_offset);
void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
// Construct Transformer Engine tensors
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto dgelu = auto dgelu = allocateTorchTensor(grad_output.size(0),
allocateTorchTensor(grad_output.size(0),
grad_output.size(1), grad_output.size(1),
DType::kByte); DType::kByte);
auto dgelu_transpose = auto dgelu_transpose = allocateTorchTensor(grad_output.size(1),
allocateTorchTensor(grad_output.size(1),
grad_output.size(0), grad_output.size(0),
DType::kByte); DType::kByte);
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
auto input_cu = makeTransformerEngineTensor(grad_output); auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(),
otype, amax.data_ptr(), scale.data_ptr(), {M, N},
scale_inv.data_ptr()); otype,
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, amax_dptr,
otype, amax.data_ptr(), scale.data_ptr(), scale_dptr,
scale_inv.data_ptr()); scale_inv_dptr);
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(),
{N, M},
otype,
amax_dptr,
scale_dptr,
scale_inv_dptr);
auto dbias_cu = makeTransformerEngineTensor(grad_bias); auto dbias_cu = makeTransformerEngineTensor(grad_bias);
// Query workspace size and allocate workspace
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(), cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(), dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
// Fill workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(), workspace.shape(),
workspace.dtype()); workspace.dtype());
// Launch kernel
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(), cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(), dbias_cu.data(), workspace.data(),
......
...@@ -414,7 +414,10 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -414,7 +414,10 @@ at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv, scale_inv,
otype_arg, otype_arg,
sm_margin, sm_margin,
zero_centered_gamma); zero_centered_gamma,
fp8_tensor, // scale_offset
fp8_tensor, // amax_offset
fp8_tensor); // scale_inv_offset
return output; return output;
} }
...@@ -460,7 +463,10 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, ...@@ -460,7 +463,10 @@ at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
scale_inv, scale_inv,
otype_arg, otype_arg,
sm_margin, sm_margin,
zero_centered_gamma); zero_centered_gamma,
fp8_tensor, // scale_offset
fp8_tensor, // amax_offset
fp8_tensor); // scale_inv_offset
return output; return output;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment