Unverified Commit bd278fff authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Enable MXFP8 LayerNorm and RMSNorm (#1487)



* Enable MXFP8 LayerNorm and RMSNorm
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix compilation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix envvar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6ff7b704
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "common/util/system.h"
#include "extensions.h" #include "extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
...@@ -70,80 +71,85 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -70,80 +71,85 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
} }
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
float eps, py::object ln_out, py::handle quantizer, float eps, py::object out, py::handle quantizer,
DType out_dtype, const int sm_margin, DType out_dtype, const int sm_margin,
const bool zero_centered_gamma) { const bool zero_centered_gamma) {
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
using namespace transformer_engine; using namespace transformer_engine;
// Input and param tensors
auto none = py::none(); auto none = py::none();
const TensorWrapper &input_tensor = makeTransformerEngineTensor(input, none); const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
const TensorWrapper &weight_tensor = makeTransformerEngineTensor(weight, none); const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
TensorWrapper bias_cu;
TensorWrapper bias_tensor;
MaybeTensor bias_grad = std::nullopt;
if (bias.has_value()) { if (bias.has_value()) {
bias_tensor = makeTransformerEngineTensor(*bias); bias_cu = makeTransformerEngineTensor(*bias);
} }
// Tensor dimensions // Tensor dimensions
size_t N = static_cast<size_t>(input_tensor.size(0)); const size_t N = static_cast<size_t>(input_cu.size(0));
size_t H = static_cast<size_t>(input_tensor.size(1)); const size_t H = static_cast<size_t>(input_cu.size(1));
std::vector<size_t> size = {N, H}; const std::vector<size_t> size = {N, H};
// Construct Transformer Engine tensors // Tensors to save for backward pass
at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat)); at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);
TensorWrapper ln_out_tensor; // Output tensor
std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer); std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
py::object ln_output; TensorWrapper out_cu;
if (out.is_none()) {
std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
} else {
out_cu = makeTransformerEngineTensor(out, quantizer);
}
// Determine whether to avoid fused kernel
bool force_unfused_kernel = false;
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
// Use high precision output from normalization if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
NoneQuantizer q{none}; // TE only supports MXFP8 norm with cuDNN backend
std::tie(ln_out_tensor, ln_output) = q.create_tensor(size, out_dtype); force_unfused_kernel = true;
} else { } else if (N % 128 != 0 || H % 128 != 0) {
if (ln_out.is_none()) { // cuDNN norm requires full tile for MXFP8
std::tie(ln_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype); force_unfused_kernel = true;
} else {
ln_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
} }
} }
TensorWrapper mu_cu = makeTransformerEngineTensor(mu); TensorWrapper unquantized_out_cu;
TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); if (force_unfused_kernel) {
NoneQuantizer q{none};
py::object unquantized_out;
std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
}
TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;
// Query workspace sizes // Query workspace size
transformer_engine::TensorWrapper workspace; transformer_engine::TensorWrapper workspace;
nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream()); zero_centered_gamma, at::cuda::getCurrentCUDAStream());
// Allocate workspaces // Allocate workspace
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = workspace =
makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());
// Launch kernel // Launch kernel
nvte_layernorm_fwd(input_tensor.data(), weight_tensor.data(), bias_tensor.data(), eps, nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
ln_out_tensor.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(), mu_cu.data(), rsigma_cu.data(), workspace.data(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
zero_centered_gamma, at::cuda::getCurrentCUDAStream()); zero_centered_gamma, at::cuda::getCurrentCUDAStream());
if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { // Quantize output if using unfused kernel
TensorWrapper cast_out_tensor; if (force_unfused_kernel) {
if (ln_out.is_none()) { nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
std::tie(cast_out_tensor, ln_out) = my_quantizer->create_tensor(size, out_dtype);
} else {
cast_out_tensor = makeTransformerEngineTensor(ln_out, quantizer);
}
nvte_quantize_noop(ln_out_tensor.data(), cast_out_tensor.data(), nullptr,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
} }
return {ln_out, py::cast(mu), py::cast(rsigma)}; return {out, py::cast(mu), py::cast(rsigma)};
} }
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
...@@ -187,69 +193,77 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, ...@@ -187,69 +193,77 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
} }
/*! \brief RMSNorm forward pass, exposed to Python.
 *
 *  Runs the fused RMSNorm kernel when possible. For MXFP8 1D scaling the
 *  fused path is only taken when the cuDNN backend is enabled via the
 *  NVTE_CUDNN_MXFP8_NORM envvar and the (N, H) dims are multiples of 128;
 *  otherwise the norm is computed in high precision and quantized afterwards.
 *
 *  \param input              Input tensor with shape (N, H).
 *  \param weight             Scale parameter (gamma).
 *  \param eps                Epsilon added under the root for stability.
 *  \param out                Output tensor; if None, one is created by the quantizer.
 *  \param quantizer          Python quantizer object (may encode MXFP8 scaling).
 *  \param out_dtype          Output data type.
 *  \param sm_margin          Number of SMs to leave free for overlapping work.
 *  \param zero_centered_gamma  If true, gamma is applied as (1 + gamma).
 *
 *  \return {out, None, rsigma} — RMSNorm has no mean, so the second slot is
 *          None to keep the return shape consistent with layernorm_fwd.
 */
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                    py::object out, py::handle quantizer,
                                    transformer_engine::DType out_dtype, const int sm_margin,
                                    const bool zero_centered_gamma) {
  using namespace transformer_engine::pytorch;
  using namespace transformer_engine;

  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);

  // Tensor dimensions
  const size_t N = static_cast<size_t>(input_cu.shape().data[0]);
  const size_t H = static_cast<size_t>(input_cu.shape().data[1]);
  const std::vector<size_t> size = {N, H};

  // Tensor to save for backward pass (per-row inverse RMS)
  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
  auto rsigma_cu = makeTransformerEngineTensor(rsigma);

  // Output tensor: construct via the quantizer when not provided by the caller
  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
  TensorWrapper out_cu;
  if (out.is_none()) {
    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
  } else {
    out_cu = makeTransformerEngineTensor(out, quantizer);
  }

  // Determine whether to avoid fused kernel
  bool force_unfused_kernel = false;
  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
      // TE only supports MXFP8 norm with cuDNN backend
      force_unfused_kernel = true;
    } else if (N % 128 != 0 || H % 128 != 0) {
      // cuDNN norm requires full tile for MXFP8
      force_unfused_kernel = true;
    }
  }
  TensorWrapper unquantized_out_cu;
  if (force_unfused_kernel) {
    // High-precision intermediate output; quantized into out_cu after the kernel.
    // NOTE(review): the py::object returned by create_tensor is dropped at the
    // end of this scope — presumably unquantized_out_cu keeps the underlying
    // storage alive; confirm against create_tensor's ownership semantics.
    NoneQuantizer q{none};
    py::object unquantized_out;
    std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
  }
  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;

  // Query workspace size (first call with an empty workspace tensor)
  transformer_engine::TensorWrapper workspace;
  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
                   workspace.data(),
                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Allocate workspace
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
                   workspace.data(),
                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }

  return {out, py::none(), py::cast(rsigma)};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment