Unverified commit fdc09f42, authored by Przemyslaw Tredak, committed by GitHub

Exposing RMSNorm in pyTorch (#306)

* Exposing RMSNorm in pyTorch extensions
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* First pass at the Python API
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Small fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added numerics tests and fixed issues
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Lint fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added RMSNorm to LayerNormMLP
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added ONNX export and tests for RMSNorm
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix python lint
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix BERT case
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Added normalization option to the TransformerLayer
Added tests
Fixed test failures
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix documentation
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kwarg bug
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix IMA and invalid type error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Increase RMSNorm threshold for bf16 case
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06d5fa97
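For orientation, the user-facing surface of this commit is a standalone RMSNorm module plus a `normalization` option on the fused modules. A minimal usage sketch (shapes, device, and sizes are illustrative; assumes a build of transformer_engine containing these changes):

import torch
import transformer_engine.pytorch as te

hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
x = torch.randn(8, 32, hidden_size, device="cuda")

# Standalone RMSNorm module added by this commit.
rmsnorm = te.RMSNorm(hidden_size, eps=1e-5).cuda()
y = rmsnorm(x)

# Fused modules gain a `normalization` keyword ('LayerNorm' or 'RMSNorm').
mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, normalization="RMSNorm").cuda()
block = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads,
                            normalization="RMSNorm").cuda()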
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
at::Tensor scaled_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
at::Tensor mask,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
if (!input.is_contiguous())
input = input.contiguous();
if (!mask.is_contiguous())
mask = mask.contiguous();
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
TORCH_CHECK(mask.size(1) == 1);
TORCH_CHECK(mask.size(2) == query_seq_len);
TORCH_CHECK(mask.size(3) == key_seq_len);
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto mask_cu = makeTransformerEngineTensor(mask);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_masked_softmax_forward(
input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_CHECK(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(),
softmax_results_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
TORCH_CHECK(output_grads.size(1) == output_grads.size(2));
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(),
softmax_results_cu.data(),
output_grads_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return output_grads;
}
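For context, these bindings are exposed to Python through the compiled extension; a sketch of a call (the pybind registration of `scaled_softmax_forward` is assumed from elsewhere in the codebase):

import torch
import transformer_engine_extensions as tex

# Shape (batches, heads, q_len, k_len); fp16/bf16 only and k_len <= 4096,
# matching the checks in scaled_softmax_forward above.
logits = torch.randn(2, 16, 128, 128, device="cuda", dtype=torch.float16)
probs = tex.scaled_softmax_forward(logits, 0.125)  # scale applied before the softmax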
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
void fused_cast_transpose(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
at::Tensor input_cast,
at::Tensor input_transpose,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
at::cuda::getCurrentCUDAStream());
}
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto grad_output_cast =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// The first call only queried the required workspace size; allocate it and run the kernel.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
return {grad_bias, grad_output_cast, grad_output_transpose};
}
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
transformer_engine::DType grad_bias_type
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);
auto grad_output_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
{N, M}, otype, amax.data_ptr(),
scale.data_ptr(), scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
transformer_engine::TensorWrapper workspace;
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
// The first call only queried the required workspace size; allocate it and run the kernel.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
workspace.data(), at::cuda::getCurrentCUDAStream());
return {grad_bias, grad_output_transpose};
}
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
at::Tensor gelu_input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(grad_output.size(0));
size_t N = static_cast<size_t>(grad_output.size(1));
DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
auto dgelu =
allocateTorchTensor(grad_output.size(0),
grad_output.size(1),
DType::kByte);
auto dgelu_transpose =
allocateTorchTensor(grad_output.size(1),
grad_output.size(0),
DType::kByte);
transformer_engine::TensorWrapper workspace;
auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
auto input_cu = makeTransformerEngineTensor(grad_output);
auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
otype, amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto dbias_cu = makeTransformerEngineTensor(grad_bias);
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
// The first call only queried the required workspace size; allocate it and run the kernel.
auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
cast_output_cu.data(), transposed_output_cu.data(),
dbias_cu.data(), workspace.data(),
at::cuda::getCurrentCUDAStream());
return {grad_bias, dgelu, dgelu_transpose};
}
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
std::vector<at::Tensor> transposed_output_list,
std::vector<at::Tensor> amax_list,
std::vector<at::Tensor> scale_inv_list,
transformer_engine::DType otype
) {
using namespace transformer_engine;
// Extract properties from PyTorch tensors
std::vector<void*> input_dptr_list, scale_dptr_list,
cast_output_dptr_list, transposed_output_dptr_list,
amax_dptr_list, scale_inv_dptr_list;
std::vector<std::vector<size_t>> input_shape_list, scale_shape_list,
cast_output_shape_list, transposed_output_shape_list,
amax_shape_list, scale_inv_shape_list;
std::vector<transformer_engine::DType> input_type_list, scale_type_list,
cast_output_type_list, transposed_output_type_list,
amax_type_list, scale_inv_type_list;
auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor,
std::vector<void*>& dptr_list,
std::vector<std::vector<size_t>>& shape_list) {
dptr_list.push_back(tensor.data_ptr());
shape_list.push_back({});
for (int d = 0; d < tensor.dim(); ++d) {
shape_list.back().push_back(tensor.size(d));
}
};
auto extract_tensor_props = [](at::Tensor& tensor,
std::vector<void*>& dptr_list,
std::vector<std::vector<size_t>>& shape_list,
std::vector<transformer_engine::DType>& type_list) {
dptr_list.push_back(tensor.data_ptr());
shape_list.push_back({});
for (int d = 0; d < tensor.dim(); ++d) {
shape_list.back().push_back(tensor.size(d));
}
type_list.push_back(GetTransformerEngineDType(tensor.scalar_type()));
};
for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
extract_tensor_props(input_list[tensor_id],
input_dptr_list,
input_shape_list,
input_type_list);
extract_tensor_props(scale_list[tensor_id],
scale_dptr_list,
scale_shape_list,
scale_type_list);
extract_tensor_props_skip_dtype(cast_output_list[tensor_id],
cast_output_dptr_list,
cast_output_shape_list);
cast_output_type_list.push_back(otype);
extract_tensor_props_skip_dtype(transposed_output_list[tensor_id],
transposed_output_dptr_list,
transposed_output_shape_list);
transposed_output_type_list.push_back(otype);
extract_tensor_props(amax_list[tensor_id],
amax_dptr_list,
amax_shape_list,
amax_type_list);
extract_tensor_props(scale_inv_list[tensor_id],
scale_inv_dptr_list,
scale_inv_shape_list,
scale_inv_type_list);
}
transformer_engine::TensorWrapper workspace;
// Construct TE tensors
std::vector<NVTETensor> nvte_input_list,
nvte_cast_output_list, nvte_transposed_output_list;
std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
auto make_tensor = [&tensor_wrappers](void* dptr,
const std::vector<size_t>& shape,
transformer_engine::DType dtype,
void* amax_dptr,
void* scale_dptr,
void* scale_inv_dptr)
-> NVTETensor {
tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
scale_dptr, scale_inv_dptr));
return tensor_wrappers.back().data();
};
for (size_t i = 0; i < input_dptr_list.size(); ++i) {
nvte_input_list.emplace_back(make_tensor(input_dptr_list[i],
input_shape_list[i],
input_type_list[i],
nullptr,
nullptr,
nullptr));
nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
cast_output_shape_list[i],
cast_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
transposed_output_shape_list[i],
transposed_output_type_list[i],
amax_dptr_list[i],
scale_dptr_list[i],
scale_inv_dptr_list[i]));
}
// Check tensor lists
NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(),
"Number of input and C output tensors must match");
NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(),
"Number of input and T output tensors must match");
// Launch TE kernel
nvte_multi_cast_transpose(nvte_input_list.size(),
nvte_input_list.data(),
nvte_cast_output_list.data(),
nvte_transposed_output_list.data(),
at::cuda::getCurrentCUDAStream());
}
at::Tensor fp8_transpose(at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t M = static_cast<size_t>(input.size(0));
size_t N = static_cast<size_t>(input.size(1));
auto output =
allocateTorchTensor(input.size(1),
input.size(0),
DType::kByte);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return output;
}
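Similarly, `fp8_transpose` can be driven from Python; a sketch (the `tex.DType` enum value is an assumption based on the extension's type bindings):

import torch
import transformer_engine_extensions as tex

# FP8 data travels in uint8 storage; `otype` tags the FP8 format.
inp = torch.randint(0, 256, (128, 256), dtype=torch.uint8, device="cuda")
out = tex.fp8_transpose(inp, tex.DType.kFloat8E4M3)  # returns shape (256, 128)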
@@ -328,6 +328,44 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
return output;
}
at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype,
const bool zero_centered_gamma) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);
float eps_float = static_cast<float>(eps);
at::Tensor output = rmsnorm_fwd_fp8_inf(input,
weight,
eps_float,
scale,
amax,
scale_inv,
otype_arg,
zero_centered_gamma);
return output;
}
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
const at::Tensor &weight,
double eps,
const bool zero_centered_gamma) {
float eps_float = static_cast<float>(eps);
at::Tensor output = rmsnorm_fwd_inf(input,
weight,
eps_float,
zero_centered_gamma);
return output;
}
TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
@@ -339,4 +377,6 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts);
}
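Once this library is loaded, the TorchScript-friendly ops become callable as `torch.ops.tex_ts.*`; a sketch of the new RMSNorm inference op (the import that loads the library is an assumption):

import torch
import transformer_engine.pytorch  # assumed to load the tex_ts library as a side effect

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
w = torch.ones(1024, device="cuda", dtype=torch.float16)
# Signature mirrors rmsnorm_fwd_inf_ts above: (input, weight, eps, zero_centered_gamma).
y = torch.ops.tex_ts.rmsnorm_fwd_inf_ts(x, w, 1e-5, False)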
@@ -7,3 +7,4 @@ from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Internal function used by multiple modules."""
from typing import Union, Dict, Any
import torch
from .. import cpp_extensions as tex
from ..fp8 import get_fp8_te_dtype
def _get_normalization_func(normalization: str,
fp8_output: bool,
is_grad_enabled: bool,
forward: bool):
fwd_normalization_funcs = {
('LayerNorm', True, True): tex.layernorm_fwd_fp8,
('LayerNorm', True, False): tex.layernorm_fwd_fp8_inf,
('LayerNorm', False, True): tex.layernorm_fwd_noalloc,
('LayerNorm', False, False): tex.layernorm_fwd_inf,
('RMSNorm', True, True): tex.rmsnorm_fwd_fp8,
('RMSNorm', True, False): tex.rmsnorm_fwd_fp8_inf,
('RMSNorm', False, True): tex.rmsnorm_fwd_noalloc,
('RMSNorm', False, False): tex.rmsnorm_fwd_inf,
}
bwd_normalization_funcs = {
'LayerNorm': tex.layernorm_bwd,
'RMSNorm': tex.rmsnorm_bwd,
}
if forward:
return fwd_normalization_funcs[(normalization, fp8_output, is_grad_enabled)]
assert not fp8_output, "FP8 output is not supported in backward normalization!"
assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!"
return bwd_normalization_funcs[normalization]
def _apply_normalization(inputmat: torch.Tensor,
ln_out: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: Union[torch.Tensor, None],
eps: float,
fp8_out: bool,
fp8_meta: Dict[str, Any],
normalization: str,
fwd_ln_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool):
normalization_func = _get_normalization_func(normalization,
fp8_out,
is_grad_enabled,
True)
inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
if fp8_out:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if is_grad_enabled:
output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out"
output_kwarg = {output_key: ln_out}
output = normalization_func(
*inputs,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
**output_kwarg,
)
else:
return normalization_func(
*inputs,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
), None, None
else:
if is_grad_enabled:
output = normalization_func(
*inputs, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
return normalization_func(
*inputs, eps, zero_centered_gamma
), None, None
if normalization == "RMSNorm":
output = (ln_out, None, output[1])
elif normalization == "LayerNorm":
output = (ln_out, output[1], output[2])
return output
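To make the contract explicit: `_apply_normalization` always returns a `(ln_out, mu, rsigma)` triple, with `mu` set to `None` for RMSNorm and both statistics `None` on the inference paths. A hypothetical call (argument values illustrative):

# Non-fp8 training path with RMSNorm (ln_bias is None and fp8_meta is unused):
out, mu, rsigma = _apply_normalization(
    inputmat, ln_out, ln_weight, None,  # ln_bias is None for RMSNorm
    1e-5,                               # eps
    False, {},                          # fp8_out, fp8_meta
    "RMSNorm",
    0,                                  # fwd_ln_sm_margin
    False,                              # zero_centered_gamma
    True,                               # is_grad_enabled
)
assert mu is None  # RMSNorm computes no mean statistic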
@@ -12,7 +12,7 @@ import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
from .. import cpp_extensions as tex
from .base import (
get_workspace,
@@ -38,22 +38,13 @@ from ..distributed import (
reduce_scatter_along_first_dim,
gather_along_first_dim,
)
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ._common import _apply_normalization
__all__ = ["LayerNormLinear"]
class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module
@@ -65,7 +56,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx,
inp: torch.Tensor,
ln_weight: torch.Tensor,
ln_bias: torch.Tensor,
ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None],
@@ -91,6 +82,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool,
ub_split_ag: bool,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -105,10 +97,9 @@ class _LayerNormLinear(torch.autograd.Function):
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
@@ -118,69 +109,35 @@ class _LayerNormLinear(torch.autograd.Function):
dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out = ln_out
)
else:
mu = rsigma = None
ln_out = layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out = cast_to_fp8(
ln_out_return,
else:
ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
ln_out, mu, rsigma = _apply_normalization(inputmat,
ln_out,
ln_weight,
ln_bias,
eps,
fp8 and not return_layernorm_output,
fp8_meta,
normalization,
fwd_ln_sm_margin,
zero_centered_gamma,
is_grad_enabled)
# If residual connection is after LN, we need `ln_out_return`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if return_layernorm_output:
ln_out_return = ln_out
if fp8:
ln_out = tex.cast_to_fp8(
ln_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
@@ -200,7 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
if update_fp8_weights:
if is_grad_enabled:
fp8_cast_transpose_fused(
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -210,13 +167,13 @@ class _LayerNormLinear(torch.autograd.Function):
)
else:
weight_t_fp8 = None
weight_fp8 = cast_to_fp8(
weight_fp8 = tex.cast_to_fp8(
weight,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward)
out = fp8_gemm(
out = tex.fp8_gemm(
weight_fp8,
fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -247,7 +204,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float()
out, _, _ = gemm(
out, _, _ = tex.gemm(
weight,
ln_out_total,
activation_dtype,
@@ -289,6 +246,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
@@ -379,7 +337,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
# DGRAD: Evaluated unconditionally to feed into Linear backward
_ = fp8_gemm(
_ = tex.fp8_gemm(
weight_t_fp8,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT,
@@ -397,7 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
else:
# DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = gemm(
_, _, _ = tex.gemm(
weight,
grad_output,
ctx.activation_dtype,
@@ -427,7 +385,7 @@ class _LayerNormLinear(torch.autograd.Function):
# WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm(
wgrad = tex.fp8_gemm(
ln_out_total_t,
fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT,
@@ -446,14 +404,14 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
)
else:
ln_out_total_c = cast_from_fp8(
ln_out_total_c = tex.cast_from_fp8(
ln_out_total,
ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
TE_DType[ctx.activation_dtype],
)
wgrad, _, _ = gemm(
wgrad, _, _ = tex.gemm(
ln_out_total_c,
grad_output,
ctx.activation_dtype,
@@ -468,7 +426,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
else:
# WGRAD
wgrad, grad_bias, _ = gemm(
wgrad, grad_bias, _ = tex.gemm(
ln_out_total,
grad_output,
ctx.activation_dtype,
@@ -496,10 +454,18 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
if not ctx.use_bias:
grad_bias = None
@@ -533,6 +499,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -555,6 +522,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -624,6 +593,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
normalization: str = 'LayerNorm',
return_bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
@@ -649,9 +619,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma
@@ -696,15 +668,18 @@ class LayerNormLinear(TransformerEngineBaseModule):
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter(
torch.empty(
in_features,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
self.layer_norm_bias = None
self.reset_layer_norm_parameters()
self.weight_tensor = torch.empty(
@@ -796,7 +771,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
if self.layer_norm_bias is not None:
init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad(
self,
@@ -915,6 +891,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_split_ag,
self.normalization,
)
out = fwd_fn(*args)
......
@@ -46,6 +46,8 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ._common import _apply_normalization
__all__ = ["LayerNormMLP"]
@@ -107,6 +109,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_split_rs: bool,
ub_split_ag: bool,
activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible
in_features = ln_weight.numel()
@@ -124,7 +127,8 @@ class _LayerNormMLP(torch.autograd.Function):
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group)
@@ -133,70 +137,39 @@ class _LayerNormMLP(torch.autograd.Function):
if ub_split_ag:
ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1:
ub_split_rs = False
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
ln_out, mu, rsigma = _apply_normalization(inputmat,
ln_out,
ln_weight,
ln_bias,
eps,
fp8 and not return_layernorm_output,
fp8_meta,
normalization,
fwd_ln_sm_margin,
zero_centered_gamma,
is_grad_enabled)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled:
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = tex.layernorm_fwd_fp8(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
fwd_ln_sm_margin,
zero_centered_gamma,
ln_out = ln_out,
)
else:
ln_out = tex.layernorm_fwd_fp8_inf(
inputmat,
ln_weight,
ln_bias,
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
if return_layernorm_output:
ln_out_return = ln_out
if fp8:
ln_out = tex.cast_to_fp8(
ln_out_return,
ln_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
)
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear
if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1)
@@ -422,6 +395,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag
ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear
if ub_split_rs:
@@ -804,10 +778,17 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
@@ -846,6 +827,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
@@ -864,6 +846,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True`
if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
@@ -942,6 +926,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tp_size: int = 1,
init_method: Optional[Callable] = None,
bias: bool = True,
normalization: str = 'LayerNorm',
activation : str = "gelu",
output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False,
@@ -960,6 +945,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.use_bias = bias
self.activation = activation
self.return_bias = return_bias
@@ -1005,15 +992,18 @@ class LayerNormMLP(TransformerEngineBaseModule):
dtype=params_dtype,
)
)
self.layer_norm_bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
self.layer_norm_bias = None
self.reset_layer_norm_parameters()
if self.activation in ['reglu', 'geglu', 'swiglu']:
@@ -1114,7 +1104,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
init.ones_(self.layer_norm_weight)
else:
init.zeros_(self.layer_norm_weight)
init.zeros_(self.layer_norm_bias)
if self.layer_norm_bias is not None:
init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad(
self,
@@ -1217,6 +1208,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_split_rs,
self.ub_split_ag,
self.activation,
self.normalization,
)
out = fwd_fn(*args)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""RMSNorm API"""
import os
from typing import Union, Tuple, Optional
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
__all__ = ["RMSNorm"]
class _RMSNorm(torch.autograd.Function):
"""functional RMSNorm"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
rmsnorm_weight: torch.Tensor,
eps: float,
fwd_rmsnorm_sm_margin: int,
bwd_rmsnorm_sm_margin: int,
zero_centered_gamma: bool,
is_grad_enabled: bool,
) -> torch.Tensor:
# Make sure input dimensions are compatible
in_features = rmsnorm_weight.numel()
assert inp.is_cuda, "TransformerEngine needs CUDA."
assert inp.shape[-1] == in_features, "Input's last dimension must match the RMSNorm weight size"
inputmat = inp.view((-1, in_features))
if is_grad_enabled:
rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight,
eps, fwd_rmsnorm_sm_margin,
zero_centered_gamma)
ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma)
ctx.inp_shape = inp.shape
ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
ctx.zero_centered_gamma = zero_centered_gamma
else:
rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight,
eps,
zero_centered_gamma)
return rmsnorm_out.view_as(inp)
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors
grad_output = grad_output.contiguous()
d_rmsnorm_out = grad_output.view(inputmat.shape)
dxmat, dgamma = tex.rmsnorm_bwd(
d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight,
ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma
)
return (
dxmat.view(ctx.inp_shape),
dgamma,
None,
None,
None,
None,
None,
)
class RMSNorm(torch.nn.Module):
r"""
Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in
the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__
.. math::
y = \frac{x}{RMS(x) + \varepsilon} * \gamma
where
.. math::
RMS(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}
:math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`
Parameters
----------
hidden_size : int
size of each input sample.
eps : float, default = 1e-5
a value added to the denominator of layer normalization for numerical stability.
sequence_parallel : bool, default = `False`
if set to `True`, uses sequence parallelism.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
zero_centered_gamma : bool, default = `False`
if set to `True`, the gamma parameter in RMSNorm is initialized to 0 and
the RMSNorm formula changes to
.. math::
y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma)
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-5,
sequence_parallel: bool = False,
params_dtype: Optional[torch.dtype] = None,
zero_centered_gamma: bool = False,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.eps = eps
self.zero_centered_gamma = zero_centered_gamma
self.weight = Parameter(
torch.empty(
hidden_size,
device=torch.cuda.current_device(),
dtype=params_dtype,
)
)
setattr(self.weight, "sequence_parallel", sequence_parallel)
self.reset_rms_norm_parameters()
# This many SMs are subtracted from the total SM count when calling the
# forward and backward RMSNorm C APIs. These envvars can be used to prevent
# the normalization kernels from using all SMs in the device, which is
# useful for cases such as communication overlap with RMSNorm.
self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
def reset_rms_norm_parameters(self) -> None:
"""Init RMSNorm params"""
if not self.zero_centered_gamma:
init.ones_(self.weight)
else:
init.zeros_(self.weight)
@no_torch_dynamo
def forward(self, inp: torch.Tensor) -> torch.Tensor:
"""RMSNorm FWD"""
if torch.is_grad_enabled():
fwd_fn = _RMSNorm.apply
args = []
else:
fwd_fn = _RMSNorm.forward
args = [None]
args += (
inp,
self.weight,
self.eps,
self.fwd_rmsnorm_sm_margin,
self.bwd_rmsnorm_sm_margin,
self.zero_centered_gamma,
torch.is_grad_enabled()
)
return fwd_fn(*args)
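For numerics checks against the fused kernel, a pure-PyTorch reference of the docstring formula above (a sketch: it reduces over the last dimension in fp32 and adds eps to the denominator outside the square root, matching the docstring and the ONNX export in this commit):

import torch

def rmsnorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5,
                zero_centered_gamma: bool = False) -> torch.Tensor:
    """Unfused reference: y = x / (RMS(x) + eps) * gamma."""
    weight = (1 + gamma) if zero_centered_gamma else gamma
    x32 = x.float()
    rms = x32.pow(2).mean(dim=-1, keepdim=True).sqrt()
    return (x32 / (rms + eps) * weight.float()).to(x.dtype)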
@@ -283,6 +283,20 @@ def onnx_te_gemm(
return output
def _ones_like(g, inp, dtype):
"""Returns a tensor filled with the scalar value 1, with the same size as input and
with dtype data-type"""
shape = g.op("Shape", inp)
# WAR ONNX spec: ConstantOfShape accepts all data types except BF16. As a
# workaround, create the ConstantOfShape with type FP32 and then Cast to BF16.
is_bf16 = dtype == torch.bfloat16
one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1],
dtype=torch.float32 if is_bf16 else dtype))
if is_bf16:
one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return one
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma):
@@ -305,19 +319,6 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
"""ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument
def ones_like(inp, dtype):
"""Returns a tensor filled with the scalar value 1, with the same size as input and
with dtype data-type"""
shape = g.op("Shape", inp)
# WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR
# create a ConstantOfShape with type FP32 and then add a Cast to BF16.
is_bf16 = dtype == torch.bfloat16
one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1],
dtype=torch.float32 if is_bf16 else dtype))
if is_bf16:
one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return one
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
@@ -328,7 +329,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
one = ones_like(weight, inputs_dtype)
one = _ones_like(g, weight, inputs_dtype)
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
@@ -344,6 +345,57 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
)
return ln
@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd_fp8"""
# pylint: disable=unused-argument
inp_dtype = get_TensorProtoDataType(inputs)
if inp_dtype != get_TensorProtoDataType(weight):
weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "f", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
one = _ones_like(g, weight, inputs_dtype)
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
norm = g.op("ReduceL2", inputs_float, axes_i=[axis])
shape = g.op("Shape", inputs_float, start_i=-1)
shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT)
n_reciprocal = g.op("Reciprocal", shape_f)
sqrt_n_reciprocal = g.op("Sqrt", n_reciprocal)
rms = g.op("Mul", norm, sqrt_n_reciprocal)
eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32))
rms_eps = g.op("Add", rms, eps_tensor)
normalized_input = g.op("Div", inputs_float, rms_eps)
result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
return result
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
@@ -355,3 +407,5 @@ register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
register_custom_op_symbolic('tex_ts::rmsnorm_fwd_fp8_inf_ts', onnx_rmsnorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::rmsnorm_fwd_inf_ts', onnx_rmsnorm_fwd, VER)
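With the symbolics registered above, a traced RMSNorm module lowers to the ONNX graph built in `onnx_rmsnorm_fwd`; a hypothetical export sketch (the opset value and any TE-specific export context are assumptions not shown in this diff):

import torch
import transformer_engine.pytorch as te

m = te.RMSNorm(1024).cuda().eval()
x = torch.randn(2, 1024, device="cuda")
torch.onnx.export(m, (x,), "rmsnorm.onnx", opset_version=15)  # opset assumed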
@@ -11,7 +11,7 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import MultiHeadAttention
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
@@ -128,6 +128,8 @@ class TransformerLayer(torch.nn.Module):
.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
qkv_weight_interleaved : bool, default = `True`
if set to `False`, the QKV weight is interpreted as a concatenation of
query, key, and value weights along the `0th` dimension. The default
@@ -220,7 +222,8 @@ class TransformerLayer(torch.nn.Module):
qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False,
bias: bool = True,
activation: str = 'gelu'
activation: str = 'gelu',
normalization: str = "LayerNorm",
) -> None:
super().__init__()
@@ -312,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
input_layernorm=not output_layernorm,
attention_type="self",
bias=bias,
normalization=normalization,
)
if layer_type == "decoder":
@@ -322,6 +326,7 @@ class TransformerLayer(torch.nn.Module):
input_layernorm=True,
attention_type="cross",
bias=bias,
normalization=normalization,
)
# LayerNorm -> activation(Linear + Bias) -> Linear
@@ -353,6 +358,7 @@ class TransformerLayer(torch.nn.Module):
ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag,
activation=activation,
normalization=normalization,
)
self.hidden_dropout = hidden_dropout
@@ -376,8 +382,12 @@ class TransformerLayer(torch.nn.Module):
hidden_size, seq_length, micro_batch_size
)
norm_module = {
"LayerNorm": LayerNorm,
"RMSNorm": RMSNorm,
}
if self.output_layernorm:
self.layernorm = LayerNorm(
self.layernorm = norm_module[normalization](
hidden_size,
eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel,
......