Commit 544dd14b authored by Przemek Tredak

Update main branch with TE 2.0 code, update version to 2.1.0.dev0


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent e5369541
@@ -6,10 +6,29 @@
 #include "extensions.h"
 
-std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
+namespace transformer_engine::pytorch {
+
+std::pair<TensorWrapper, py::object> createOutputTensor(const NVTEShape &shape, DType dtype,
+                                                        py::handle quantizer) {
+  std::vector<size_t> shape_vec;
+  for (int i = 0; i < shape.ndim; i++) {
+    size_t t = shape.data[i];
+    shape_vec.push_back(t);
+  }
+  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+  return my_quantizer->create_tensor(shape_vec, dtype);
+}
+
+std::pair<TensorWrapper, py::object> createOutputTensor(std::vector<size_t> &shape, DType dtype,
+                                                        py::handle quantizer) {
+  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+  return my_quantizer->create_tensor(shape, dtype);
+}
+
+}  // namespace transformer_engine::pytorch
+
+std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                       const at::Tensor &mu, const at::Tensor &rsigma,
                                       const at::Tensor &gamma, const int sm_margin,
                                       const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch;
   const auto &dz_ = dz.contiguous();
   const auto &x_ = x.contiguous();
   const auto &mu_ = mu.contiguous();
@@ -47,61 +66,57 @@ std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                  at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                  zero_centered_gamma, at::cuda::getCurrentCUDAStream());
 
-  return {dx, dgamma, dbeta};
+  return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)};
 }
 
-std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
-                                          const at::Tensor &bias, float eps, at::Tensor scale,
-                                          at::Tensor amax, at::Tensor scale_inv,
-                                          transformer_engine::DType otype, const int sm_margin,
-                                          const bool zero_centered_gamma, const int scale_offset,
-                                          const int amax_offset, const int scale_inv_offset) {
-  using namespace transformer_engine;
-  const auto &input_ = input.contiguous();
-
-  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
-  return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype,
-                                   sm_margin, zero_centered_gamma, scale_offset, amax_offset,
-                                   scale_inv_offset);
-}
-
-std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
-    const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps,
-    at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv,
-    transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma,
-    const int scale_offset, const int amax_offset, const int scale_inv_offset) {
-  using namespace transformer_engine;
-  const auto &input_ = input.contiguous();
-  const auto &weight_ = weight.contiguous();
-  const auto &bias_ = bias.contiguous();
+std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
+                                      float eps, py::object ln_out, py::handle quantizer,
+                                      DType out_dtype, const int sm_margin,
+                                      const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
+  auto none = py::none();
+  const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none);
+
+  TensorWrapper bias_tensor;
+  MaybeTensor bias_grad = std::nullopt;
+  if (bias.has_value()) {
+    bias_tensor = makeTransformerEngineTensor(*bias);
+  }
 
   // Tensor dimensions
-  size_t N = static_cast<size_t>(input.size(0));
-  size_t H = static_cast<size_t>(input.size(1));
-
-  // Get pointers for FP8 scale, amax, scale-inverse
-  void *scale_dptr = getDataPtr(scale, scale_offset);
-  void *amax_dptr = getDataPtr(amax, amax_offset);
-  void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
+  size_t N = static_cast<size_t>(input_tensor.size(0));
+  size_t H = static_cast<size_t>(input_tensor.size(1));
+  std::vector<size_t> size = {N, H};
 
   // Construct Transformer Engine tensors
-  DType itype = GetTransformerEngineDType(input.scalar_type());
-  auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto input_cu = makeTransformerEngineTensor(input_);
-  auto gamma_cu = makeTransformerEngineTensor(weight_);
-  auto beta_cu = makeTransformerEngineTensor(bias_);
-  auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
-                                          scale_inv_dptr);
-  auto mu_cu = makeTransformerEngineTensor(mu);
-  auto rsigma_cu = makeTransformerEngineTensor(rsigma);
+  at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
+  at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
+
+  TensorWrapper ln_out_tensor;
+  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+  py::object ln_output;
+
+  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
+    // Use high precision output from normalization
+    NoneQuantizer q{none};
+    std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype);
+  } else {
+    if (ln_out.is_none()) {
+      std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
+    } else {
+      ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
+    }
+  }
+  TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
+  TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);
 
   // Query workspace sizes
   transformer_engine::TensorWrapper workspace;
-  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
-                     mu_cu.data(), rsigma_cu.data(), workspace.data(),
+  nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
+                     ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
@@ -111,66 +126,30 @@ std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
 
   // Launch kernel
-  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
-                     mu_cu.data(), rsigma_cu.data(), workspace.data(),
+  nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps,
+                     ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
                      at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                      zero_centered_gamma, at::cuda::getCurrentCUDAStream());
 
-  return {ln_out, mu, rsigma};
-}
-
-at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight,
-                                 const at::Tensor &bias, float eps, at::Tensor scale,
-                                 at::Tensor amax, at::Tensor scale_inv,
-                                 transformer_engine::DType otype, const int sm_margin,
-                                 const bool zero_centered_gamma, const int scale_offset,
-                                 const int amax_offset, const int scale_inv_offset) {
-  // This is a specialized version of layernorm_fwd_fp8, optimized for inference,
-  // which only returns the normalized output.
-  std::vector<at::Tensor> out =
-      layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin,
-                        zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);
-  return out[0];
-}
-
-std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor &weight,
-                                      const at::Tensor &bias, float eps, const int sm_margin,
-                                      const bool zero_centered_gamma) {
-  using namespace transformer_engine;
-  DType itype = GetTransformerEngineDType(input.scalar_type());
-  const auto &input_ = input.contiguous();
-  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
-  return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma);
-}
-
-std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
-                                              const at::Tensor &bias, at::Tensor ln_out, float eps,
-                                              const int sm_margin, const bool zero_centered_gamma) {
-  using namespace transformer_engine;
-  DType itype = GetTransformerEngineDType(input.scalar_type());
-  return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(),
-                                   at::Tensor(), itype, sm_margin, zero_centered_gamma);
-}
-
-at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
-                             const at::Tensor &bias, float eps, const int sm_margin,
-                             const bool zero_centered_gamma) {
-  // This is a specialized version of layernorm_fwd, optimized for inference,
-  // which only returns the normalized output.
-  std::vector<at::Tensor> out =
-      layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma);
-  return out[0];
+  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
+    TensorWrapper cast_out_tensor;
+    if (ln_out.is_none()) {
+      std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
+    } else {
+      cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
+    }
+
+    nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
+                       at::cuda::getCurrentCUDAStream());
+  }
+
+  return {ln_out, py::cast(mu), py::cast(rsigma)};
 }
 
-std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
+std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                     const at::Tensor &rsigma, const at::Tensor &gamma,
                                     const int sm_margin, const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch;
   const auto &dz_ = dz.contiguous();
   const auto &x_ = x.contiguous();
   const auto &rsigma_ = rsigma.contiguous();
@@ -204,57 +183,48 @@ std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                 at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                 zero_centered_gamma, at::cuda::getCurrentCUDAStream());
 
-  return {dx, dgamma};
+  return {py::cast(dx), py::cast(dgamma)};
 }
 
-std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
-                                        float eps, at::Tensor scale, at::Tensor amax,
-                                        at::Tensor scale_inv, transformer_engine::DType otype,
-                                        const int sm_margin, const bool zero_centered_gamma,
-                                        const int scale_offset, const int amax_offset,
-                                        const int scale_inv_offset) {
-  using namespace transformer_engine;
-  const auto &input_ = input.contiguous();
-  const auto &weight_ = weight.contiguous();
-
-  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype)));
-  return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype,
-                                 sm_margin, zero_centered_gamma, scale_offset, amax_offset,
-                                 scale_inv_offset);
-}
-
-std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight,
-                                                float eps, at::Tensor scale, at::Tensor ln_out,
-                                                at::Tensor amax, at::Tensor scale_inv,
-                                                transformer_engine::DType otype,
-                                                const int sm_margin, const bool zero_centered_gamma,
-                                                const int scale_offset, const int amax_offset,
-                                                const int scale_inv_offset) {
-  using namespace transformer_engine;
+std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
+                                    py::object ln_out, py::handle quantizer,
+                                    transformer_engine::DType otype, const int sm_margin,
+                                    const bool zero_centered_gamma) {
+  using namespace transformer_engine::pytorch;
   using namespace transformer_engine;
+  auto none = py::none();
+  const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none);
+  const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none);
 
   // Tensor dimensions
-  size_t N = static_cast<size_t>(input.size(0));
-  size_t H = static_cast<size_t>(input.size(1));
-
-  // Get pointers for FP8 scale, amax, scale-inverse
-  void *scale_dptr = getDataPtr(scale, scale_offset);
-  void *amax_dptr = getDataPtr(amax, amax_offset);
-  void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);
+  size_t N = static_cast<size_t>(input_tensor.shape().data[0]);
+  size_t H = static_cast<size_t>(input_tensor.shape().data[1]);
 
   // Construct Transformer Engine tensors
-  DType itype = GetTransformerEngineDType(input.scalar_type());
   auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
-  auto input_cu = makeTransformerEngineTensor(input);
-  auto gamma_cu = makeTransformerEngineTensor(weight);
-  auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr,
-                                          scale_inv_dptr);
+  std::vector<size_t> size = {N, H};
+  TensorWrapper ln_out_tensor;
+  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+  py::object ln_output;
+  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
+    // Use high precision output from normalization
+    NoneQuantizer q{none};
+    std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, otype);
+  } else {
+    if (ln_out.is_none()) {
+      std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype);
+    } else {
+      ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
+    }
+  }
   auto rsigma_cu = makeTransformerEngineTensor(rsigma);
 
   // Query workspace sizes
   transformer_engine::TensorWrapper workspace;
-  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
-                   workspace.data(),
+  nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(),
+                   rsigma_cu.data(), workspace.data(),
                    at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                    zero_centered_gamma, at::cuda::getCurrentCUDAStream());
@@ -264,55 +234,22 @@ std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
 
   // Launch kernel
-  nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(),
-                   workspace.data(),
+  nvte_rmsnorm_fwd(input_tensor.data(), weight_tensor.data(), eps, ln_out_tensor.data(),
+                   rsigma_cu.data(), workspace.data(),
                    at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                    zero_centered_gamma, at::cuda::getCurrentCUDAStream());
 
-  return {ln_out, rsigma};
-}
-
-at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
-                               at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
-                               transformer_engine::DType otype, const int sm_margin,
-                               const bool zero_centered_gamma, const int scale_offset,
-                               const int amax_offset, const int scale_inv_offset) {
-  // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference,
-  // which only returns the normalized output.
-  std::vector<at::Tensor> out =
-      rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin,
-                      zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset);
-  return out[0];
-}
-
-std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps,
-                                    const int sm_margin, const bool zero_centered_gamma) {
-  using namespace transformer_engine;
-  const auto &input_ = input.contiguous();
-  const auto &weight_ = weight.contiguous();
-
-  DType itype = GetTransformerEngineDType(input.scalar_type());
-  auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype)));
-
-  return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma);
-}
-
-std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
-                                            at::Tensor ln_out, float eps, const int sm_margin,
-                                            const bool zero_centered_gamma) {
-  using namespace transformer_engine;
-  DType itype = GetTransformerEngineDType(input.scalar_type());
-  return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(),
-                                 at::Tensor(), itype, sm_margin, zero_centered_gamma);
-}
-
-at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
-                           const int sm_margin, const bool zero_centered_gamma) {
-  // This is a specialized version of rmsnorm_fwd, optimized for inference,
-  // which only returns the normalized output.
-  std::vector<at::Tensor> out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma);
-  return out[0];
+  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
+    TensorWrapper cast_out_tensor;
+    if (ln_out.is_none()) {
+      std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, otype);
+    } else {
+      cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
+    }
+
+    nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
+                       at::cuda::getCurrentCUDAStream());
+  }
+
+  return {ln_out, py::none(), py::cast(rsigma)};
 }
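Both new forward entry points above repeat the same quantizer dispatch. The following is a condensed sketch of that shared branch, written only with calls that appear in this diff; the helper itself is hypothetical and does not exist in the codebase:

// Sketch: how layernorm_fwd/rmsnorm_fwd pick the kernel's output tensor.
// With MXFP8 1D scaling the norm kernel writes a high-precision buffer and the
// cast happens afterwards via nvte_quantize_noop; otherwise the quantizer
// either allocates the (possibly quantized) output or wraps the caller's ln_out.
std::pair<TensorWrapper, py::object> pick_norm_output(  // hypothetical helper
    std::unique_ptr<Quantizer> &my_quantizer, py::object &ln_out, py::handle quantizer,
    const std::vector<size_t> &size, DType out_dtype) {
  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
    NoneQuantizer q{py::none()};  // high-precision output, quantized later
    return q.create_tensor(size, out_dtype);
  }
  if (ln_out.is_none()) {
    return my_quantizer->create_tensor(size, out_dtype);
  }
  return {makeTransformerEngineTensor(ln_out, quantizer), ln_out};
}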
@@ -10,6 +10,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list) {
   using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
 
   NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
              "Number of input row list and padded row list must match.");
@@ -11,6 +11,7 @@
 std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
     at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
     int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
+  using namespace transformer_engine::pytorch;
   const int num_tokens = input.size(0);
   int num_cols = input.size(1);
   const int topK = indices.size(1);
@@ -96,6 +97,7 @@ at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
 at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
                              at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                              int64_t topK) {
+  using namespace transformer_engine::pytorch;
   int num_cols = input.size(1);
 
   // Activations type
@@ -129,6 +131,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
 std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
                                                      const transformer_engine::DType dtype,
                                                      at::Tensor row_id_map, at::Tensor prob) {
+  using namespace transformer_engine::pytorch;
   const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
   const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
   int num_cols = input_bwd.size(1);
@@ -4,14 +4,131 @@
  * See LICENSE for license information.
  ************************************************************************/
 
+#include "pybind.h"
+
+#include <Python.h>
+#include <pybind11/cast.h>
+#include <pybind11/detail/common.h>
+#include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+
+#include <stdexcept>
+
+#include "../common.h"
 #include "../extensions.h"
+#include "common.h"
+
+namespace transformer_engine::pytorch {
+
+PyTypeObject *Float8TensorPythonClass = nullptr;  /// TODO Remove
+PyTypeObject *Float8TensorBasePythonClass = nullptr;
+PyTypeObject *Float8QuantizerClass = nullptr;
+PyTypeObject *MXFP8TensorPythonClass = nullptr;  /// TODO Remove
+PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
+PyTypeObject *MXFP8QuantizerClass = nullptr;
+
+void init_float8_extension() {
+  if (Float8TensorPythonClass) return;
+  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
+  Float8QuantizerClass =
+      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
+  Float8TensorPythonClass =
+      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
+  auto fp8_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base");
+  Float8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase"));
+  NVTE_CHECK(Float8TensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch Float8 extension.");
+}
+
+void init_mxfp8_extension() {
+  if (MXFP8TensorPythonClass) return;
+  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
+  MXFP8QuantizerClass =
+      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
+  MXFP8TensorPythonClass =
+      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor"));
+  auto fp8_base_module =
+      py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base");
+  MXFP8TensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
+      PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase"));
+  NVTE_CHECK(MXFP8TensorPythonClass != nullptr,
+             "Internal error: could not initialize pyTorch MXFP8 extension.");
+}
+
+void init_extension() {
+  init_float8_extension();
+  init_mxfp8_extension();
+}
+
+}  // namespace transformer_engine::pytorch
 
 #include "common/util/pybind_helper.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m)
+
+  m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"),
+        py::arg("output") = py::none(), py::arg("noop") = py::none());
+  m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
+        py::arg("otype"));
+  m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
+        "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
+  m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
+        py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"),
+        py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"),
+        py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"),
+        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
+        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
+        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
+  m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu,
+        "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"),
+        py::arg("quantizer"));
 
   // Permutation functions
   m.def("moe_permute_fwd", moe_permute_fwd);
@@ -42,116 +159,24 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
 
   // Other granular functions
-  m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8",
-        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
-        py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
-        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8",
-        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
-        py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"),
-        py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"),
-        py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0,
-        py::arg("scale_inv_offset") = 0);
-  m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8",
-        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
-        py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
-        py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0,
-        py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8",
-        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("weight"),
-        py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"),
-        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop,
-        "Cast + Transpose with noop option", py::call_guard<py::gil_scoped_release>(),
-        py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"),
-        py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD",
-        py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
-        py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
-        py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD",
-        py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("scale"),
-        py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu,
-        "Fused Cast + Transpose + BGRAD + DGELU", py::call_guard<py::gil_scoped_release>(),
-        py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
-        py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
-        py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
-        "Fused SwiGLU backward + FP8 cast + FP8 transpose",
-        py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
-        py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
-        py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
-        py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
-        "Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
-  m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
-        "Fused Multi-tensor Cast + Transpose with allocating output tensors",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard<py::gil_scoped_release>(),
-        py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8",
-        py::call_guard<py::gil_scoped_release>(), py::arg("input"), py::arg("scale"),
-        py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"),
-        py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
-  m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard<py::gil_scoped_release>(),
-        py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"),
-        py::arg("scale_inv_offset") = 0);
-  m.def("te_gemm", &te_gemm, "CublasLt GEMM");  /// TODO Think
-  m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM");
+  m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
+        py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
+        py::arg("sm_margin"), py::arg("zero_centered_gamma"));
+  m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
+  m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"),
+        py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
+        py::arg("zero_centered_gamma"));
+  m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
+  m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose",
+        py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
+  m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
-  m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked,
-        "Fused Attention FP8/BF16/FP16 FWD with packed QKV",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked,
-        "Fused Attention FP8/BF16/FP16 BWD with packed QKV",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked,
-        "Fused Attention FP8/BF16/FP16 FWD with packed KV",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked,
-        "Fused Attention FP8/BF16/FP16 BWD with packed KV",
-        py::call_guard<py::gil_scoped_release>());
   m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V",
-        py::call_guard<py::gil_scoped_release>());
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
   m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop,
-        "Transpose with FP8 I/O with noop option.", py::call_guard<py::gil_scoped_release>());
-  m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("relu", &relu, "ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("reglu", &reglu, "ReGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("srelu", &srelu, "Squared ReLU with FP8 output", py::call_guard<py::gil_scoped_release>());
-  m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard<py::gil_scoped_release>());
-  m.def("drelu", &drelu, "Backward of ReLU", py::call_guard<py::gil_scoped_release>());
-  m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard<py::gil_scoped_release>());
-  m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard<py::gil_scoped_release>());
-  m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard<py::gil_scoped_release>());
-  m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard<py::gil_scoped_release>());
-  m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard<py::gil_scoped_release>());
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"),
+        py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
 
   m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
         py::call_guard<py::gil_scoped_release>());
   m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
@@ -233,30 +258,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::call_guard<py::gil_scoped_release>());
 
   // Data structures
-  py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
+  py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
       .def(py::init<>())
-      .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale)
-      .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
-      .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
+      .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale)
+      .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv)
+      .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history);
 
-  py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors")
-      .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
-      .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
-      .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT)
-      .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
-      .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT)
-      .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT)
-      .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT)
-      .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT)
-      .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT);
+  py::enum_<transformer_engine::pytorch::FP8FwdTensors>(m, "FP8FwdTensors")
+      .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT)
+      .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT)
+      .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT)
+      .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT)
+      .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT)
+      .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT)
+      .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT)
+      .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT)
+      .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT);
 
-  py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors")
-      .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
-      .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1)
-      .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2)
-      .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2)
-      .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3)
-      .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3);
+  py::enum_<transformer_engine::pytorch::FP8BwdTensors>(m, "FP8BwdTensors")
+      .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1)
+      .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1)
+      .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2)
+      .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2)
+      .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3)
+      .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3);
 
   py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
       .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
@@ -265,54 +290,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
            py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none());
 
-  py::class_<CommOverlap>(m, "CommOverlap")
+  py::class_<CommOverlap, std::shared_ptr<CommOverlap>, transformer_engine::CommOverlapBase,
+             transformer_engine::CommOverlapCore>(m, "CommOverlap")
       .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int, int, int,
-                    int, int, bool, bool>(),
+                    int, int, int, int, bool, bool, bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"),
           py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS,
-           py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16,
-           py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false)
-      .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard<py::gil_scoped_release>())
-      .def("split_overlap_rs", &CommOverlap::split_overlap_rs,
-           py::call_guard<py::gil_scoped_release>())
-      .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs,
-           py::call_guard<py::gil_scoped_release>())
-      .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf,
-           py::call_guard<py::gil_scoped_release>())
-      .def("get_ubuf_output", &CommOverlap::get_ubuf_output,
-           py::call_guard<py::gil_scoped_release>())
-      .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv,
-           py::call_guard<py::gil_scoped_release>())
-      .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard<py::gil_scoped_release>())
-      .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard<py::gil_scoped_release>())
-      .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>());
+           py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
+           py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
+           py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
+      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
+           py::arg("quantizer"), py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
+           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
+      .def("set_buffer_params", &CommOverlap::set_buffer_params);
 
-  py::class_<CommOverlapP2P>(m, "CommOverlapP2P")
+  py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
+             transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
+      m, "CommOverlapP2P")
       .def(py::init<const std::vector<size_t> &, at::ScalarType, CommOverlapHelper *, int,
-                    transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(),
+                    transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool,
+                    bool>(),
           py::call_guard<py::gil_scoped_release>(), py::arg("buffer_shape"),
           py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"),
           py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1,
-           py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false,
-           py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false)
-      .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag,
-           py::call_guard<py::gil_scoped_release>())
-      .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs,
-           py::call_guard<py::gil_scoped_release>())
-      .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag,
-           py::call_guard<py::gil_scoped_release>())
-      .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs,
-           py::call_guard<py::gil_scoped_release>())
-      .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf,
-           py::call_guard<py::gil_scoped_release>())
-      .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output,
-           py::call_guard<py::gil_scoped_release>())
-      .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv,
-           py::call_guard<py::gil_scoped_release>())
-      .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard<py::gil_scoped_release>())
-      .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm,
-           py::call_guard<py::gil_scoped_release>())
-      .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap,
-           py::call_guard<py::gil_scoped_release>());
+           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
+           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
+           py::arg("use_ce") = true, py::arg("aggregate") = false)
+      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
+           py::arg("quantizer"), py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
+           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
+      .def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
 }
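One detail worth noting in the rebound fp8_transpose: py::kw_only() makes every argument listed after it keyword-only on the Python side. A minimal, self-contained pybind11 module (hypothetical names, for illustration only) showing the effect:

#include <pybind11/pybind11.h>
namespace py = pybind11;

// Stand-in function; only the binding style matters here.
int scale(int value, int factor, int offset) { return value * factor + offset; }

PYBIND11_MODULE(demo, m) {
  // "offset" may only be passed by keyword:
  //   demo.scale(2, 3, offset=1)  -> OK
  //   demo.scale(2, 3, 1)         -> TypeError
  m.def("scale", &scale, py::arg("value"), py::arg("factor"), py::kw_only(),
        py::arg("offset"));
}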
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <pybind.h>
#include "common.h"
#include "pybind.h"
#include "torch/torch.h"
#include "util.h"
namespace transformer_engine::pytorch {
constexpr size_t MXFP8_BLOCK_SIZE = 32;
Quantizer::Quantizer(const py::handle& quantizer) {
if (quantizer.is_none()) {
this->rowwise_usage = true;
this->columnwise_usage = true;
this->internal = false;
} else {
this->rowwise_usage = quantizer.attr("rowwise_usage").cast<bool>();
this->columnwise_usage = quantizer.attr("columnwise_usage").cast<bool>();
this->internal = quantizer.attr("internal").cast<bool>();
this->quantizer = quantizer;
}
}
Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
const DType type = quantizer.attr("dtype").cast<DType>();
this->amax = amax;
this->scale = scale;
this->dtype = type;
}
std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
at::TensorOptions opts;
opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA);
std::vector<int64_t> torch_shape;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
}
at::Tensor ret;
if (rowwise_data.has_value()) {
ret = std::move(*rowwise_data);
} else {
ret = at::empty(torch_shape, opts);
}
TensorWrapper tensor;
tensor.set_rowwise_data(ret.data_ptr(), dtype, shape);
return {std::move(tensor), py::cast(ret)};
}
void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()),
getTensorShape(scale));
at::TensorOptions opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
getTensorShape(amax));
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> rowwise_torch_shape;
std::vector<int64_t> columnwise_torch_shape;
if (!shape.empty()) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape.back()));
}
for (size_t i = 0; i < shape.size(); ++i) {
if (i < shape.size() - 1) {
columnwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
rowwise_torch_shape.emplace_back(static_cast<int64_t>(shape[i]));
}
at::TensorOptions opts;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(rowwise_torch_shape, opts);
}
}
const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
at::Tensor columnwise_data;
bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
if (create_transpose) {
columnwise_data = at::empty(columnwise_torch_shape, opts);
}
const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none();
opts = opts.dtype(torch::kFloat32);
at::Tensor scale_inv = at::reciprocal(scale);
py::object ret;
if (internal) {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorBasePythonClass));
ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
} else {
py::handle Float8TensorClass(reinterpret_cast<PyObject*>(Float8TensorPythonClass));
ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype),
"data"_a = py_data, "fp8_scale_inv"_a = scale_inv,
"fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data,
"quantizer"_a = this->quantizer);
}
TensorWrapper tensor(this->get_scaling_mode());
if (rowwise_usage) {
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
if (create_transpose) {
std::vector<size_t> transposed_shape;
for (auto s : columnwise_torch_shape) {
transposed_shape.emplace_back(static_cast<size_t>(s));
}
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape);
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
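The transpose buffer's shape built above is the row-wise shape with its last dimension rotated to the front. A standalone illustration of just that shape computation (the helper below is hypothetical, written for this note, not part of the extension):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Rotate the last dimension to the front, as Float8Quantizer::create_tensor
// does when sizing the data_transpose buffer.
std::vector<std::int64_t> columnwise_shape(const std::vector<std::size_t> &shape) {
  std::vector<std::int64_t> result;
  if (!shape.empty()) result.push_back(static_cast<std::int64_t>(shape.back()));
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) {
    result.push_back(static_cast<std::int64_t>(shape[i]));
  }
  return result;
}

int main() {
  for (auto d : columnwise_shape({16, 128, 512})) std::cout << d << ' ';
  std::cout << '\n';  // prints: 512 16 128
}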
MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
}
void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(NVTE_MXFP8_1D_SCALING);
at::TensorOptions opts;
at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv,
columnwise_scale_inv; // TODO(pgadzinski) - change
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
auto last_dim = static_cast<size_t>(torch_shape.back());
NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0,
"MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE,
" (got shape=", torch_shape, ")");
at::Tensor data;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data = std::move(*rowwise_data);
} else {
data = at::empty(torch_shape, opts);
}
auto sinv0 = roundup(numel / last_dim, 128);
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}
if (columnwise_usage) {
auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4);
auto sinv1 = roundup(last_dim, 128);
columnwise_data = at::empty(torch_shape, opts);
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);
py::object ret;
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorBasePythonClass));
ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
} else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = data, "columnwise_data"_a = columnwise_data,
"rowwise_scale_inv"_a = rowwise_scale_inv,
"columnwise_scale_inv"_a = columnwise_scale_inv,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
}
return {std::move(tensor), std::move(ret)};
}
} // namespace transformer_engine::pytorch
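For reference, the scale-inverse shapes computed in MXFP8Quantizer::create_tensor work out as follows for a 256 x 1024 tensor, assuming roundup(x, m) rounds x up to the next multiple of m (the version below is a standalone stand-in for the extension's helper, for illustration only):

#include <cstddef>
#include <iostream>

constexpr std::size_t MXFP8_BLOCK_SIZE = 32;

// Assumed semantics: round x up to the next multiple of m.
std::size_t roundup(std::size_t x, std::size_t m) { return ((x + m - 1) / m) * m; }

int main() {
  const std::size_t rows = 256, cols = 1024, numel = rows * cols;
  // Row-wise: one E8M0 scale per 1x32 block, padded out to 128x4 tiles.
  std::cout << roundup(numel / cols, 128) << " x "
            << roundup(cols / MXFP8_BLOCK_SIZE, 4) << "\n";  // 256 x 32
  // Column-wise: one scale per 32x1 block, padded the other way.
  std::cout << roundup(numel / (cols * MXFP8_BLOCK_SIZE), 4) << " x "
            << roundup(cols, 128) << "\n";  // 8 x 1024
}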
@@ -9,20 +9,22 @@
 #include <string>

+#include "common/common.h"
 #include "extensions.h"

-void fused_amax_and_scale_update_after_reduction(
-    const at::Tensor &amax_reduction_buffer, std::vector<at::Tensor> amax_histories,
-    std::vector<at::Tensor> scales, std::vector<at::Tensor> scale_invs,
-    const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) {
+void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
+                                                 std::vector<at::Tensor> amax_histories,
+                                                 std::vector<at::Tensor> scales,
+                                                 const std::string &amax_compute_algo,
+                                                 transformer_engine::DType fp8_dtype,
+                                                 float margin) {
   using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   size_t num_tensors = amax_histories.size();
   std::vector<Tensor> t_amax_histories(num_tensors);
   std::vector<Tensor> t_scales(num_tensors);
-  std::vector<Tensor> t_scale_invs(num_tensors);
   std::vector<NVTETensor> te_amax_histories(num_tensors);
   std::vector<NVTETensor> te_scales(num_tensors);
-  std::vector<NVTETensor> te_scale_invs(num_tensors);
   for (size_t i = 0; i < num_tensors; i++) {
     t_amax_histories[i].data.dptr = amax_histories[i].data_ptr();
     auto amax_sizes = amax_histories[i].sizes().vec();
@@ -36,18 +38,11 @@ void fused_amax_and_scale_update_after_reduction(
     t_scales[i].data.shape = scale_shape;
     t_scales[i].data.dtype = DType::kFloat32;
-    t_scale_invs[i].data.dptr = scale_invs[i].data_ptr();
-    auto scale_inv_sizes = scale_invs[i].sizes().vec();
-    std::vector<size_t> scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()};
-    t_scale_invs[i].data.shape = scale_inv_shape;
-    t_scale_invs[i].data.dtype = DType::kFloat32;
     te_amax_histories[i] = reinterpret_cast<NVTETensor>(&t_amax_histories[i]);
     te_scales[i] = reinterpret_cast<NVTETensor>(&t_scales[i]);
-    te_scale_invs[i] = reinterpret_cast<NVTETensor>(&t_scale_invs[i]);
   }
   nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
       makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales,
-      te_scale_invs, amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
+      amax_compute_algo.c_str(), static_cast<NVTEDType>(fp8_dtype), margin,
       at::cuda::getCurrentCUDAStream());
 }
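The hunk above drops the per-tensor scale_inv updates from the fused delayed-scaling kernel; with the TE 2.0 quantizer refactor the scale inverses are tracked with the quantized tensors themselves. As a reminder of what the NVTE call computes per tensor, here is a hedged host-side sketch of the default scale update (assuming the "max" amax_compute_algo and E4M3's max-normal value of 448; a sketch, not the library's actual implementation):

#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of the delayed-scaling update: pick amax from the history, then set
// scale = 2^(floor(log2(fp8_max / amax)) - margin), guarding degenerate amax.
float delayed_scaling_scale(const std::vector<float>& amax_history, float margin,
                            float fp8_max = 448.0f) {
  if (amax_history.empty()) return 1.0f;
  float amax = *std::max_element(amax_history.begin(), amax_history.end());
  if (!(amax > 0.0f) || !std::isfinite(amax)) return 1.0f;
  return std::exp2(std::floor(std::log2(fp8_max / amax)) - margin);
}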
@@ -7,7 +7,7 @@
 #include "extensions.h"

 at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
@@ -38,7 +38,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
 at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                    float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
@@ -65,7 +65,7 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
 }

 at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
@@ -105,7 +105,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
 at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                           float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
@@ -132,7 +132,7 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
 }

 at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
@@ -159,7 +159,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
 at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                        at::Tensor softmax_results_,
                                                        float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   auto output_grads = output_grads_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
@@ -188,7 +188,7 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
 }

 at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
@@ -220,7 +220,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
 at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
                                                          at::Tensor softmax_results_,
                                                          float scale_factor) {
-  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) {
using namespace transformer_engine::pytorch;
if (input.scaling_mode() == NVTE_INVALID_SCALING) {
NVTE_ERROR("Invalid scaling mode for swizzle.");
} else if (input.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
return;
}
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
NVTEBasicTensor scale_inv;
if (rowwise) {
scale_inv = input.get_rowwise_scale_inv();
} else {
scale_inv = input.get_columnwise_scale_inv();
}
auto input_shape = nvte_shape_to_vector(input.shape());
auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
// Allocate memory for swizzled output.
auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
std::vector<int64_t> scale_inv_shape_int;
for (size_t i = 0; i < scale_inv_shape.size(); ++i) {
scale_inv_shape_int.push_back(static_cast<int64_t>(scale_inv_shape[i]));
}
auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options);
void* scale_inv_dptr = scale_inv.data_ptr;
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
  // Rebuild a minimal input wrapper so only the requested direction is swizzled.
  // Any 8-bit dtype works here; the kernel only touches the scale factors.
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0,
scale_inv_shape);
}
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
}
}
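swizzle_scaling_factors rewrites the scale-inverse pointer on the wrapper in place, so a caller can apply it to each GEMM operand just before the matmul; delayed-scaling tensors pass through untouched via the early return. A hedged usage sketch (prepare_mxfp8_operands is a hypothetical helper, not part of this file):

// Hypothetical call site: swizzle both operands of an MXFP8 GEMM in place.
// Safe to call unconditionally on 8-bit operands; delayed tensor scaling
// returns early inside swizzle_scaling_factors.
void prepare_mxfp8_operands(transformer_engine::TensorWrapper& a, bool a_rowwise,
                            transformer_engine::TensorWrapper& b, bool b_rowwise) {
  swizzle_scaling_factors(a, a_rowwise);
  swizzle_scaling_factors(b, b_rowwise);
}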
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input),
DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr,
getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
// Return immediately if tensor is empty
if (scale_inv.numel() == 0) {
return swizzled_scale_inv;
}
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
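Unlike the in-place variant, the standalone rowwise_swizzle and columnwise_swizzle return the swizzled buffer and leave lifetime management to the caller. A hedged sketch of how they might be exposed to Python (the module registration below is an assumption for illustration; the real bindings live elsewhere in the extension):

#include <torch/extension.h>

// Hypothetical pybind11 registration for the standalone swizzle helpers.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rowwise_swizzle", &rowwise_swizzle,
        "Return rowwise MXFP8 scale-inverse factors in swizzled layout");
  m.def("columnwise_swizzle", &columnwise_swizzle,
        "Return columnwise MXFP8 scale-inverse factors in swizzled layout");
}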