"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "cd54a8cd15efdc9c97c28d0c8816cafe8c35c4d8"
Unverified Commit bd278fff authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

[PyTorch] Enable MXFP8 LayerNorm and RMSNorm (#1487)



* Enable MXFP8 LayerNorm and RMSNorm
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix compilation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix envvar
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6ff7b704
......@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "common/util/system.h"
#include "extensions.h"
namespace transformer_engine::pytorch {
......@@ -70,80 +71,85 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
}
/*! \brief LayerNorm forward pass with optional quantized output.
 *
 *  \param input               Input tensor of shape (N, H).
 *  \param weight              Affine scale (gamma).
 *  \param bias                Optional affine shift (beta).
 *  \param eps                 Epsilon added to variance for numerical stability.
 *  \param out                 Output tensor, or py::none() to allocate one via the quantizer.
 *  \param quantizer           Quantizer handle controlling the output representation.
 *  \param out_dtype           Data type for the (unquantized) normalization output.
 *  \param sm_margin           Number of SMs to reserve (subtracted from the launch SM count).
 *  \param zero_centered_gamma Whether gamma is stored zero-centered (kernel adds 1).
 *
 *  \return {out, mu, rsigma} — output plus the per-row mean and inverse
 *          standard deviation saved for the backward pass.
 */
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                      float eps, py::object out, py::handle quantizer,
                                      DType out_dtype, const int sm_margin,
                                      const bool zero_centered_gamma) {
  using namespace transformer_engine;

  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);
  TensorWrapper bias_cu;
  if (bias.has_value()) {
    bias_cu = makeTransformerEngineTensor(*bias);
  }

  // Tensor dimensions
  const size_t N = static_cast<size_t>(input_cu.size(0));
  const size_t H = static_cast<size_t>(input_cu.size(1));
  const std::vector<size_t> size = {N, H};

  // Tensors to save for backward pass
  at::Tensor mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
  at::Tensor rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
  TensorWrapper mu_cu = makeTransformerEngineTensor(mu);
  TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma);

  // Output tensor (allocate through the quantizer if caller did not provide one)
  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
  TensorWrapper out_cu;
  if (out.is_none()) {
    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
  } else {
    out_cu = makeTransformerEngineTensor(out, quantizer);
  }

  // Determine whether to avoid fused kernel
  bool force_unfused_kernel = false;
  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
      // TE only supports MXFP8 norm with cuDNN backend
      force_unfused_kernel = true;
    } else if (N % 128 != 0 || H % 128 != 0) {
      // cuDNN norm requires full tile for MXFP8
      force_unfused_kernel = true;
    }
  }

  // High-precision staging buffer for the unfused path.
  // NOTE(review): unquantized_out must live at function scope — it is the
  // py::object that owns the buffer unquantized_out_cu points at, and that
  // buffer is consumed by the kernel launch and nvte_quantize_noop below.
  // Declaring it inside the if-block would free the storage too early.
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
    NoneQuantizer q{none};
    std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
  }
  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;

  // Query workspace size (first call with empty workspace only sizes it)
  transformer_engine::TensorWrapper workspace;
  nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), workspace.data(),
                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Allocate workspace
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), workspace.data(),
                     at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                     zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }

  return {out, py::cast(mu), py::cast(rsigma)};
}
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
......@@ -187,69 +193,77 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
}
/*! \brief RMSNorm forward pass with optional quantized output.
 *
 *  \param input               Input tensor of shape (N, H).
 *  \param weight              Scale (gamma).
 *  \param eps                 Epsilon added to the mean square for numerical stability.
 *  \param out                 Output tensor, or py::none() to allocate one via the quantizer.
 *  \param quantizer           Quantizer handle controlling the output representation.
 *  \param out_dtype           Data type for the (unquantized) normalization output.
 *  \param sm_margin           Number of SMs to reserve (subtracted from the launch SM count).
 *  \param zero_centered_gamma Whether gamma is stored zero-centered (kernel adds 1).
 *
 *  \return {out, None, rsigma} — RMSNorm has no mean, so the middle slot is
 *          None to keep the same layout as layernorm_fwd.
 */
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                    py::object out, py::handle quantizer,
                                    transformer_engine::DType out_dtype, const int sm_margin,
                                    const bool zero_centered_gamma) {
  using namespace transformer_engine;

  // Input and param tensors
  auto none = py::none();
  const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none);
  const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none);

  // Tensor dimensions
  const size_t N = static_cast<size_t>(input_cu.shape().data[0]);
  const size_t H = static_cast<size_t>(input_cu.shape().data[1]);
  const std::vector<size_t> size = {N, H};

  // Tensors to save for backward pass
  auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
  auto rsigma_cu = makeTransformerEngineTensor(rsigma);

  // Output tensor (allocate through the quantizer if caller did not provide one)
  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
  TensorWrapper out_cu;
  if (out.is_none()) {
    std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype);
  } else {
    out_cu = makeTransformerEngineTensor(out, quantizer);
  }

  // Determine whether to avoid fused kernel
  bool force_unfused_kernel = false;
  if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) {
    if (!transformer_engine::getenv<bool>("NVTE_CUDNN_MXFP8_NORM", false)) {
      // TE only supports MXFP8 norm with cuDNN backend
      force_unfused_kernel = true;
    } else if (N % 128 != 0 || H % 128 != 0) {
      // cuDNN norm requires full tile for MXFP8
      force_unfused_kernel = true;
    }
  }

  // High-precision staging buffer for the unfused path.
  // NOTE(review): unquantized_out must live at function scope — it is the
  // py::object that owns the buffer unquantized_out_cu points at, and that
  // buffer is consumed by the kernel launch and nvte_quantize_noop below.
  // Declaring it inside the if-block would free the storage too early.
  TensorWrapper unquantized_out_cu;
  py::object unquantized_out;
  if (force_unfused_kernel) {
    NoneQuantizer q{none};
    std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype);
  }
  TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu;

  // Query workspace size (first call with empty workspace only sizes it)
  transformer_engine::TensorWrapper workspace;
  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
                   workspace.data(),
                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Allocate workspace
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace =
      makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

  // Launch kernel
  nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(),
                   workspace.data(),
                   at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
                   zero_centered_gamma, at::cuda::getCurrentCUDAStream());

  // Quantize output if using unfused kernel
  if (force_unfused_kernel) {
    nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr,
                       at::cuda::getCurrentCUDAStream());
  }

  return {out, py::none(), py::cast(rsigma)};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment