Unverified commit fdc09f42, authored by Przemyslaw Tredak and committed via GitHub
Browse files

Exposing RMSNorm in pyTorch (#306)



* Exposing RMSNorm in pyTorch extensions
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* First pass at the Python API
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Small fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added numerics tests and fixed issues
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Lint fixes
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added RMSNorm to LayerNormMLP
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added ONNX export and tests for RMSNorm
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix python lint
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix BERT case
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Added normalization option to the TransformerLayer
Added tests
Fixed test failures
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix documentation
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix kwarg bug
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix IMA and invalid type error
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Increase RMSNorm threshold for bf16 case
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 06d5fa97
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
// Launch the Transformer Engine scaled-softmax forward kernel:
// softmax over the last (key) dimension of `scale_factor * input`.
// Expects a 4D [batches, attn_heads, query_seq_len, key_seq_len] tensor
// in fp16 or bf16; returns a freshly allocated result of the same shape.
at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
) {
  using namespace transformer_engine;

  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  const int batches = input.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);

  // Kernel limits on sequence lengths.
  TORCH_CHECK(key_seq_len <= 4096);
  TORCH_CHECK(query_seq_len > 1);

  // Output buffer; autograd tracking is disabled for the raw kernel result.
  at::Tensor softmax_results = torch::empty(
      {batches, attn_heads, query_seq_len, key_seq_len},
      input.options().requires_grad(false));

  // Wrap the torch tensors for the TE C API and launch on the current stream.
  auto in_cu = makeTransformerEngineTensor(input);
  auto out_cu = makeTransformerEngineTensor(softmax_results);
  nvte_scaled_softmax_forward(in_cu.data(), out_cu.data(), scale_factor,
                              at::cuda::getCurrentCUDAStream());

  return softmax_results;
}
// Backward pass of the scaled softmax.
// Computes the gradient w.r.t. the input in place inside (a contiguous copy
// of) `output_grad_` and returns that buffer.
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
                                   at::Tensor softmax_results_,
                                   float scale_factor
) {
  using namespace transformer_engine;

  // TE kernels require contiguous buffers.
  auto output_grads = output_grad_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
  auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

  // Produce gradients in place: the incoming-gradient buffer is reused as
  // the output-gradient buffer (first and third kernel arguments alias).
  nvte_scaled_softmax_backward(
      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
      scale_factor, at::cuda::getCurrentCUDAStream());

  return output_grads;
}
// Forward pass of scaled softmax with an explicit mask tensor.
// input: 4D [batches, attn_heads, query_seq_len, key_seq_len], fp16/bf16.
// mask:  4D [batches or 1, 1, query_seq_len, key_seq_len] — broadcast across
//        heads (dim 1 must be 1) and across batches when its first dim is 1.
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
                                         at::Tensor mask,
                                         float scale_factor
) {
  using namespace transformer_engine;

  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  // TE kernels require contiguous buffers.
  if (!input.is_contiguous())
    input = input.contiguous();
  if (!mask.is_contiguous())
    mask = mask.contiguous();

  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);

  // Kernel limits and mask-broadcast constraints.
  TORCH_CHECK(key_seq_len <= 4096);
  TORCH_CHECK(query_seq_len > 1);
  TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
  TORCH_CHECK(mask.size(1) == 1);
  TORCH_CHECK(mask.size(2) == query_seq_len);
  TORCH_CHECK(mask.size(3) == key_seq_len);

  // Output buffer; autograd tracking disabled for the raw kernel result.
  auto act_options = input.options().requires_grad(false);
  auto softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  auto input_cu = makeTransformerEngineTensor(input);
  auto mask_cu = makeTransformerEngineTensor(mask);
  auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

  nvte_scaled_masked_softmax_forward(
      input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
      scale_factor, at::cuda::getCurrentCUDAStream());

  return softmax_results;
}
// Backward pass of the scaled masked softmax.
// Computes the gradient w.r.t. the input in place inside (a contiguous copy
// of) `output_grad_` and returns that buffer.
// NOTE(review): this reuses the unmasked nvte_scaled_softmax_backward kernel,
// presumably because the additive mask contributes no gradient — confirm
// against the kernel documentation.
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
                                          at::Tensor softmax_results_,
                                          float scale_factor
) {
  using namespace transformer_engine;

  auto output_grads = output_grad_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // Bug fix: these messages previously said "expected 3D tensor" even though
  // the condition checks for 4 dimensions (copy-paste from the 3D variant).
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
  auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

  // Produce gradients in place (first and third kernel arguments alias).
  nvte_scaled_softmax_backward(
      output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(),
      scale_factor, at::cuda::getCurrentCUDAStream());

  return output_grads;
}
// Forward pass of scaled softmax with an implicit upper-triangular (causal)
// mask applied by the kernel. input: 3D [attn_batches, seq_len, seq_len],
// fp16/bf16; returns a freshly allocated result of the same shape.
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
                                                      float scale_factor
) {
  using namespace transformer_engine;

  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_CHECK(seq_len <= 2048);
  // Added check: the triangular mask only makes sense for a square attention
  // matrix, and the output below is allocated square from size(1) alone.
  // The backward pass already enforces this; enforce it here too instead of
  // handing a mismatched shape to the kernel.
  TORCH_CHECK(input.size(1) == input.size(2));

  // Output buffer; autograd tracking disabled for the raw kernel result.
  auto act_options = input.options().requires_grad(false);
  auto softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  auto input_cu = makeTransformerEngineTensor(input);
  auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

  nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(),
                                                  softmax_results_cu.data(),
                                                  scale_factor,
                                                  at::cuda::getCurrentCUDAStream());
  return softmax_results;
}
// Backward pass of the causally-masked (upper-triangular) scaled softmax.
// Computes the gradient w.r.t. the input in place inside (a contiguous copy
// of) `output_grads_` and returns that buffer.
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor
) {
  using namespace transformer_engine;

  // TE kernels require contiguous buffers.
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  // The triangular mask implies a square attention matrix.
  TORCH_CHECK(output_grads.size(1) == output_grads.size(2));

  auto output_grads_cu = makeTransformerEngineTensor(output_grads);
  auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

  // Produce gradients in place (first and third kernel arguments alias).
  nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(),
                                                   softmax_results_cu.data(),
                                                   output_grads_cu.data(),
                                                   scale_factor,
                                                   at::cuda::getCurrentCUDAStream());

  return output_grads;
}
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
// Fused FP8 cast + transpose.
// Quantizes the 2D [M, N] `input` into the caller-preallocated `input_cast`
// ([M, N]) and `input_transpose` ([N, M]) buffers with dtype `otype`, using
// `scale`, and recording `amax` / `scale_inv` FP8 metadata as it goes.
void fused_cast_transpose(at::Tensor input,
                          at::Tensor scale,
                          at::Tensor amax,
                          at::Tensor scale_inv,
                          at::Tensor input_cast,
                          at::Tensor input_transpose,
                          transformer_engine::DType otype
) {
  using namespace transformer_engine;

  size_t M = static_cast<size_t>(input.size(0));
  size_t N = static_cast<size_t>(input.size(1));

  auto input_cu = makeTransformerEngineTensor(input);
  // Both outputs share the same scale/amax/scale_inv metadata pointers.
  auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype,
                                                    amax.data_ptr(), scale.data_ptr(),
                                                    scale_inv.data_ptr());
  auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype,
                                                         amax.data_ptr(), scale.data_ptr(),
                                                         scale_inv.data_ptr());

  nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(),
                      at::cuda::getCurrentCUDAStream());
}
// Fused kernel for the gradient path: in a single pass over the 2D [M, N]
// `grad_output` it produces the bias gradient (dbias), the FP8 cast of
// grad_output ([M, N]) and its FP8 transpose ([N, M]).
// Returns {grad_bias, grad_output_cast, grad_output_transpose}.
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                   at::Tensor scale,
                                                   at::Tensor amax,
                                                   at::Tensor scale_inv,
                                                   transformer_engine::DType otype
) {
  using namespace transformer_engine;

  size_t M = static_cast<size_t>(grad_output.size(0));
  size_t N = static_cast<size_t>(grad_output.size(1));

  // grad_bias keeps grad_output's dtype; the FP8 outputs are raw byte tensors.
  DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
  auto grad_output_cast =
      allocateTorchTensor(grad_output.size(0),
                          grad_output.size(1),
                          DType::kByte);
  auto grad_output_transpose =
      allocateTorchTensor(grad_output.size(1),
                          grad_output.size(0),
                          DType::kByte);

  auto input_cu = makeTransformerEngineTensor(grad_output);
  auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N},
                                                    otype, amax.data_ptr(), scale.data_ptr(),
                                                    scale_inv.data_ptr());
  auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
                                                          {N, M}, otype, amax.data_ptr(),
                                                          scale.data_ptr(), scale_inv.data_ptr());
  auto dbias_cu = makeTransformerEngineTensor(grad_bias);
  transformer_engine::TensorWrapper workspace;

  // First call with an empty workspace: queries the required workspace
  // shape/dtype without doing the real computation.
  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
                            transposed_output_cu.data(), dbias_cu.data(),
                            workspace.data(), at::cuda::getCurrentCUDAStream());

  // Fill workspace and launch the kernel for real.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());
  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(),
                            transposed_output_cu.data(), dbias_cu.data(),
                            workspace.data(), at::cuda::getCurrentCUDAStream());

  return {grad_bias, grad_output_cast, grad_output_transpose};
}
// Like fused_cast_transpose_bgrad, but the incoming 2D [M, N] gradient is
// already FP8 (`otype`): only the transpose and the bias gradient are
// computed. The bias-gradient dtype is caller-specified via `grad_bias_type`
// (it cannot be derived from the FP8 byte tensor's torch dtype).
// Returns {grad_bias, grad_output_transpose}.
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
                                                  at::Tensor scale,
                                                  at::Tensor amax,
                                                  at::Tensor scale_inv,
                                                  transformer_engine::DType otype,
                                                  transformer_engine::DType grad_bias_type
) {
  using namespace transformer_engine;

  size_t M = static_cast<size_t>(grad_output.size(0));
  size_t N = static_cast<size_t>(grad_output.size(1));

  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type);
  auto grad_output_transpose =
      allocateTorchTensor(grad_output.size(1),
                          grad_output.size(0),
                          DType::kByte);

  // The input is wrapped with explicit FP8 metadata since its torch dtype
  // is just bytes.
  auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N},
                                              otype, amax.data_ptr(), scale.data_ptr(),
                                              scale_inv.data_ptr());
  auto transposed_output_cu = makeTransformerEngineTensor(grad_output_transpose.data_ptr(),
                                                          {N, M}, otype, amax.data_ptr(),
                                                          scale.data_ptr(), scale_inv.data_ptr());
  auto dbias_cu = makeTransformerEngineTensor(grad_bias);
  transformer_engine::TensorWrapper workspace;

  // First call with an empty workspace: queries the required workspace
  // shape/dtype without doing the real computation.
  nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
                           workspace.data(), at::cuda::getCurrentCUDAStream());

  // Fill workspace and launch the kernel for real.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());
  nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(),
                           workspace.data(), at::cuda::getCurrentCUDAStream());

  return {grad_bias, grad_output_transpose};
}
// Fused GELU-backward + FP8 cast + transpose + bias-gradient.
// In one pass over the 2D [M, N] `grad_output` (with the saved forward
// activation input `gelu_input`) it produces the bias gradient, the FP8
// dGELU result ([M, N]) and its FP8 transpose ([N, M]).
// Returns {grad_bias, dgelu, dgelu_transpose}.
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
                                                         at::Tensor gelu_input,
                                                         at::Tensor scale,
                                                         at::Tensor amax,
                                                         at::Tensor scale_inv,
                                                         transformer_engine::DType otype
) {
  using namespace transformer_engine;

  size_t M = static_cast<size_t>(grad_output.size(0));
  size_t N = static_cast<size_t>(grad_output.size(1));

  // grad_bias keeps grad_output's dtype; the FP8 outputs are raw byte tensors.
  DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type());
  auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type);
  auto dgelu =
      allocateTorchTensor(grad_output.size(0),
                          grad_output.size(1),
                          DType::kByte);
  auto dgelu_transpose =
      allocateTorchTensor(grad_output.size(1),
                          grad_output.size(0),
                          DType::kByte);

  transformer_engine::TensorWrapper workspace;
  auto gelu_input_cu = makeTransformerEngineTensor(gelu_input);
  auto input_cu = makeTransformerEngineTensor(grad_output);
  auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N},
                                                    otype, amax.data_ptr(), scale.data_ptr(),
                                                    scale_inv.data_ptr());
  auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M},
                                                          otype, amax.data_ptr(), scale.data_ptr(),
                                                          scale_inv.data_ptr());
  auto dbias_cu = makeTransformerEngineTensor(grad_bias);

  // First call with an empty workspace: queries the required workspace
  // shape/dtype without doing the real computation.
  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
                                  cast_output_cu.data(), transposed_output_cu.data(),
                                  dbias_cu.data(), workspace.data(),
                                  at::cuda::getCurrentCUDAStream());

  // Fill workspace and launch the kernel for real.
  auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
  workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
                                          workspace.shape(),
                                          workspace.dtype());
  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(),
                                  cast_output_cu.data(), transposed_output_cu.data(),
                                  dbias_cu.data(), workspace.data(),
                                  at::cuda::getCurrentCUDAStream());

  return {grad_bias, dgelu, dgelu_transpose};
}
// Casts and transposes a batch of tensors to FP8 in a single fused kernel
// launch. For every i: cast_output_list[i] receives cast(input_list[i]) and
// transposed_output_list[i] receives its transpose, each using the i-th
// tensor's scale / amax / scale_inv FP8 metadata.
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_list,
                                std::vector<at::Tensor> scale_inv_list,
                                transformer_engine::DType otype
) {
  using namespace transformer_engine;

  // Extract properties from PyTorch tensors into flat parallel lists.
  std::vector<void*> input_dptr_list, scale_dptr_list,
      cast_output_dptr_list, transposed_output_dptr_list,
      amax_dptr_list, scale_inv_dptr_list;
  std::vector<std::vector<size_t>> input_shape_list, scale_shape_list,
      cast_output_shape_list, transposed_output_shape_list,
      amax_shape_list, scale_inv_shape_list;
  std::vector<transformer_engine::DType> input_type_list, scale_type_list,
      cast_output_type_list, transposed_output_type_list,
      amax_type_list, scale_inv_type_list;

  // Records data pointer and shape only; used for the FP8 outputs, whose TE
  // dtype is supplied by the caller (`otype`) rather than the torch dtype.
  auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor,
                                            std::vector<void*>& dptr_list,
                                            std::vector<std::vector<size_t>>& shape_list) {
    dptr_list.push_back(tensor.data_ptr());
    shape_list.push_back({});
    for (int d = 0; d < tensor.dim(); ++d) {
      shape_list.back().push_back(tensor.size(d));
    }
  };
  // Records data pointer, shape, and the TE dtype derived from torch's.
  auto extract_tensor_props = [](at::Tensor& tensor,
                                 std::vector<void*>& dptr_list,
                                 std::vector<std::vector<size_t>>& shape_list,
                                 std::vector<transformer_engine::DType>& type_list) {
    dptr_list.push_back(tensor.data_ptr());
    shape_list.push_back({});
    for (int d = 0; d < tensor.dim(); ++d) {
      shape_list.back().push_back(tensor.size(d));
    }
    type_list.push_back(GetTransformerEngineDType(tensor.scalar_type()));
  };

  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    extract_tensor_props(input_list[tensor_id],
                         input_dptr_list,
                         input_shape_list,
                         input_type_list);
    extract_tensor_props(scale_list[tensor_id],
                         scale_dptr_list,
                         scale_shape_list,
                         scale_type_list);
    extract_tensor_props_skip_dtype(cast_output_list[tensor_id],
                                    cast_output_dptr_list,
                                    cast_output_shape_list);
    cast_output_type_list.push_back(otype);
    extract_tensor_props_skip_dtype(transposed_output_list[tensor_id],
                                    transposed_output_dptr_list,
                                    transposed_output_shape_list);
    transposed_output_type_list.push_back(otype);
    extract_tensor_props(amax_list[tensor_id],
                         amax_dptr_list,
                         amax_shape_list,
                         amax_type_list);
    extract_tensor_props(scale_inv_list[tensor_id],
                         scale_inv_dptr_list,
                         scale_inv_shape_list,
                         scale_inv_type_list);
  }

  // NOTE(review): `workspace` is declared but never passed to the kernel below.
  transformer_engine::TensorWrapper workspace;

  // Construct TE tensors.
  std::vector<NVTETensor> nvte_input_list,
      nvte_cast_output_list, nvte_transposed_output_list;
  // Keeps the wrappers alive until the kernel launch; the NVTETensor handles
  // stored in the lists point into these wrappers.
  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
  auto make_tensor = [&tensor_wrappers](void* dptr,
                                        const std::vector<size_t>& shape,
                                        transformer_engine::DType dtype,
                                        void* amax_dptr,
                                        void* scale_dptr,
                                        void* scale_inv_dptr)
      -> NVTETensor {
    tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr,
                                                             scale_dptr, scale_inv_dptr));
    return tensor_wrappers.back().data();
  };
  for (size_t i = 0; i < input_dptr_list.size(); ++i) {
    // Inputs carry no FP8 metadata; both outputs share the per-tensor
    // amax/scale/scale_inv pointers.
    nvte_input_list.emplace_back(make_tensor(input_dptr_list[i],
                                             input_shape_list[i],
                                             input_type_list[i],
                                             nullptr,
                                             nullptr,
                                             nullptr));
    nvte_cast_output_list.emplace_back(make_tensor(cast_output_dptr_list[i],
                                                   cast_output_shape_list[i],
                                                   cast_output_type_list[i],
                                                   amax_dptr_list[i],
                                                   scale_dptr_list[i],
                                                   scale_inv_dptr_list[i]));
    nvte_transposed_output_list.emplace_back(make_tensor(transposed_output_dptr_list[i],
                                                         transposed_output_shape_list[i],
                                                         transposed_output_type_list[i],
                                                         amax_dptr_list[i],
                                                         scale_dptr_list[i],
                                                         scale_inv_dptr_list[i]));
  }

  // Check tensor lists.
  NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(),
             "Number of input and C output tensors must match");
  NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(),
             "Number of input and T output tensors must match");

  // Launch TE kernel on the current stream.
  nvte_multi_cast_transpose(nvte_input_list.size(),
                            nvte_input_list.data(),
                            nvte_cast_output_list.data(),
                            nvte_transposed_output_list.data(),
                            at::cuda::getCurrentCUDAStream());
}
// Plain 2D transpose of an FP8 tensor: [M, N] byte data in, [N, M] byte
// data out, interpreted as `otype` by the TE kernel. Returns a new tensor.
at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
) {
  using namespace transformer_engine;

  const size_t M = static_cast<size_t>(input.size(0));
  const size_t N = static_cast<size_t>(input.size(1));

  // Output buffer with swapped dimensions.
  auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte);

  auto in_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
  auto out_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype);
  nvte_transpose(in_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());

  return output;
}
...@@ -328,6 +328,44 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, ...@@ -328,6 +328,44 @@ at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input,
return output; return output;
} }
// TorchScript-friendly wrapper around the FP8 RMSNorm inference kernel.
// Converts the script-level argument types (double eps, int64 dtype enum)
// and forwards to rmsnorm_fwd_fp8_inf. `fp8_tensor` is accepted for schema
// compatibility but not used here.
at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input,
                                  const at::Tensor &weight,
                                  double eps,
                                  at::Tensor scale,
                                  at::Tensor amax,
                                  at::Tensor scale_inv,
                                  int64_t fp8_tensor,
                                  int64_t otype,
                                  const bool zero_centered_gamma) {
  const transformer_engine::DType te_otype = reverse_map_dtype(otype);
  return rmsnorm_fwd_fp8_inf(input,
                             weight,
                             static_cast<float>(eps),
                             scale,
                             amax,
                             scale_inv,
                             te_otype,
                             zero_centered_gamma);
}
// TorchScript-friendly wrapper around the RMSNorm inference kernel:
// narrows the script-level double eps to float and forwards the call.
at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input,
                              const at::Tensor &weight,
                              double eps,
                              const bool zero_centered_gamma) {
  return rmsnorm_fwd_inf(input,
                         weight,
                         static_cast<float>(eps),
                         zero_centered_gamma);
}
TORCH_LIBRARY(tex_ts, m) { TORCH_LIBRARY(tex_ts, m) {
m.def("cast_to_fp8_ts", &cast_to_fp8_ts); m.def("cast_to_fp8_ts", &cast_to_fp8_ts);
m.def("cast_from_fp8_ts", &cast_from_fp8_ts); m.def("cast_from_fp8_ts", &cast_from_fp8_ts);
...@@ -339,4 +377,6 @@ TORCH_LIBRARY(tex_ts, m) { ...@@ -339,4 +377,6 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("te_gemm_ts", &te_gemm_ts); m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts);
m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts);
} }
...@@ -7,3 +7,4 @@ from .layernorm_linear import LayerNormLinear ...@@ -7,3 +7,4 @@ from .layernorm_linear import LayerNormLinear
from .linear import Linear from .linear import Linear
from .layernorm_mlp import LayerNormMLP from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Internal function used by multiple modules."""
from typing import Union, Dict, Any
import torch
from .. import cpp_extensions as tex
from ..fp8 import get_fp8_te_dtype
def _get_normalization_func(normalization: str,
                            fp8_output: bool,
                            is_grad_enabled: bool,
                            forward: bool):
    """Select the extension kernel for the requested normalization setup.

    Arguments:
        normalization: "LayerNorm" or "RMSNorm".
        fp8_output: whether the forward kernel emits FP8 output.
        is_grad_enabled: training (True) vs. inference (False) variant.
        forward: pick the forward (True) or backward (False) kernel.
    """
    if not forward:
        # Backward kernels exist only for the non-FP8 training path.
        assert not fp8_output, "FP8 output is not supported in backward normalization!"
        assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!"
        backward_funcs = {
            'LayerNorm': tex.layernorm_bwd,
            'RMSNorm': tex.rmsnorm_bwd,
        }
        return backward_funcs[normalization]

    forward_funcs = {
        ('LayerNorm', True, True): tex.layernorm_fwd_fp8,
        ('LayerNorm', True, False): tex.layernorm_fwd_fp8_inf,
        ('LayerNorm', False, True): tex.layernorm_fwd_noalloc,
        ('LayerNorm', False, False): tex.layernorm_fwd_inf,
        ('RMSNorm', True, True): tex.rmsnorm_fwd_fp8,
        ('RMSNorm', True, False): tex.rmsnorm_fwd_fp8_inf,
        ('RMSNorm', False, True): tex.rmsnorm_fwd_noalloc,
        ('RMSNorm', False, False): tex.rmsnorm_fwd_inf,
    }
    return forward_funcs[(normalization, fp8_output, is_grad_enabled)]
def _apply_normalization(inputmat:torch.Tensor,
                         ln_out: torch.Tensor,
                         ln_weight: torch.Tensor,
                         ln_bias: Union[torch.Tensor, None],
                         eps: float,
                         fp8_out: bool,
                         fp8_meta: Dict[str, Any],
                         normalization: str,
                         fwd_ln_sm_margin: int,
                         zero_centered_gamma: bool,
                         is_grad_enabled: bool):
    """Run the LayerNorm/RMSNorm forward kernel and normalize its return shape.

    Dispatches to the matching extension kernel (FP8 vs. non-FP8,
    training vs. inference) and always returns a 3-tuple
    ``(ln_out, mu, rsigma)``:
      * ``mu`` is ``None`` for RMSNorm (no mean statistic is produced);
      * both ``mu`` and ``rsigma`` are ``None`` on the inference paths.

    # NOTE(review): the kernels appear to expect ``ln_bias`` to be ``None``
    # exactly for RMSNorm (bias-less) calls — confirm against the kernel
    # signatures.
    """
    normalization_func = _get_normalization_func(normalization,
                                                 fp8_out,
                                                 is_grad_enabled,
                                                 True)

    # Bias-taking kernels get (input, weight, bias); bias-less ones (input, weight).
    inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)
    if fp8_out:
        fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

        if is_grad_enabled:
            # The keyword for the preallocated output buffer differs between
            # the LayerNorm and RMSNorm kernel families.
            output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out"
            output_kwarg = {output_key: ln_out}
            output = normalization_func(
                *inputs,
                eps,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                fwd_ln_sm_margin,
                zero_centered_gamma,
                **output_kwarg,
            )
        else:
            # Inference path: no statistics are produced.
            return normalization_func(
                *inputs,
                eps,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                zero_centered_gamma,
            ), None, None
    else:
        if is_grad_enabled:
            # The "noalloc" kernels write into the caller-provided ln_out buffer.
            output = normalization_func(
                *inputs, ln_out, eps,
                fwd_ln_sm_margin, zero_centered_gamma
            )
        else:
            # Inference path: no statistics are produced.
            return normalization_func(
                *inputs, eps, zero_centered_gamma
            ), None, None
    # Training paths fall through here: rebuild a uniform (ln_out, mu, rsigma)
    # tuple from the kernel-specific return values.
    if normalization == "RMSNorm":
        # RMSNorm returns (_, rsigma) — no mean.
        output = (ln_out, None, output[1])
    elif normalization == "LayerNorm":
        output = (ln_out, output[1], output[2])
    return output
...@@ -12,7 +12,7 @@ import torch ...@@ -12,7 +12,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import transformer_engine_extensions as tex from .. import cpp_extensions as tex
from .base import ( from .base import (
get_workspace, get_workspace,
...@@ -38,22 +38,13 @@ from ..distributed import ( ...@@ -38,22 +38,13 @@ from ..distributed import (
reduce_scatter_along_first_dim, reduce_scatter_along_first_dim,
gather_along_first_dim, gather_along_first_dim,
) )
from ..cpp_extensions import (
fp8_gemm,
gemm,
fp8_cast_transpose_fused,
layernorm_fwd_fp8,
layernorm_fwd_fp8_inf,
layernorm_fwd_inf,
cast_to_fp8,
cast_from_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ._common import _apply_normalization
__all__ = ["LayerNormLinear"]
__all__ = ["LayerNormLinear"]
class _LayerNormLinear(torch.autograd.Function): class _LayerNormLinear(torch.autograd.Function):
"""LayerNormLinear semi-top level module """LayerNormLinear semi-top level module
...@@ -65,7 +56,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -65,7 +56,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx, ctx,
inp: torch.Tensor, inp: torch.Tensor,
ln_weight: torch.Tensor, ln_weight: torch.Tensor,
ln_bias: torch.Tensor, ln_bias: Union[torch.Tensor, None],
weight: torch.Tensor, weight: torch.Tensor,
weight_fp8: Union[torch.Tensor, None], weight_fp8: Union[torch.Tensor, None],
weight_t_fp8: Union[torch.Tensor, None], weight_t_fp8: Union[torch.Tensor, None],
...@@ -91,6 +82,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -91,6 +82,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_wgrad: bool, ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_split_ag: bool, ub_split_ag: bool,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -105,10 +97,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -105,10 +97,9 @@ class _LayerNormLinear(torch.autograd.Function):
# Cast for native AMP # Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if ub_split_ag: if ub_split_ag:
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
...@@ -118,69 +109,35 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -118,69 +109,35 @@ class _LayerNormLinear(torch.autograd.Function):
dim_size[0] = dim_size[0] * tp_world_size dim_size[0] = dim_size[0] * tp_world_size
ub_obj_lnout = get_ub("qkv_fprop") ub_obj_lnout = get_ub("qkv_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0) ln_out = ub_obj_lnout.get_ubuf_output(0)
if fp8: else:
ln_out_dtype = torch.uint8 if fp8 else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output: ln_out, mu, rsigma = _apply_normalization(inputmat,
if is_grad_enabled: ln_out,
if not ub_split_ag:
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = layernorm_fwd_fp8(
inputmat,
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
fp8_meta["scaling_fwd"], fp8 and not return_layernorm_output,
tex.FP8FwdTensors.GEMM1_INPUT, fp8_meta,
fp8_dtype_forward, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
ln_out = ln_out is_grad_enabled)
) # If residual connection is after LN, we need `ln_out_return`
else: # tensor in higher precision, this comes at the cost
mu = rsigma = None # of an extra fp8 cast.
ln_out = layernorm_fwd_fp8_inf( if return_layernorm_output:
inputmat, ln_out_return = ln_out
ln_weight, if fp8:
ln_bias, ln_out = tex.cast_to_fp8(
eps, ln_out,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
if is_grad_enabled:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out_return, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out = cast_to_fp8(
ln_out_return,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) )
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear # Column Parallel Linear
if ub_split_ag: if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out_total = ub_obj_lnout.get_ubuf_output(1)
...@@ -200,7 +157,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -200,7 +157,7 @@ class _LayerNormLinear(torch.autograd.Function):
if update_fp8_weights: if update_fp8_weights:
if is_grad_enabled: if is_grad_enabled:
fp8_cast_transpose_fused( tex.fp8_cast_transpose_fused(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
...@@ -210,13 +167,13 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -210,13 +167,13 @@ class _LayerNormLinear(torch.autograd.Function):
) )
else: else:
weight_t_fp8 = None weight_t_fp8 = None
weight_fp8 = cast_to_fp8( weight_fp8 = tex.cast_to_fp8(
weight, weight,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
fp8_dtype_forward) fp8_dtype_forward)
out = fp8_gemm( out = tex.fp8_gemm(
weight_fp8, weight_fp8,
fp8_meta["scaling_fwd"].scale_inv, fp8_meta["scaling_fwd"].scale_inv,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
...@@ -247,7 +204,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -247,7 +204,7 @@ class _LayerNormLinear(torch.autograd.Function):
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
torch.amax(weight).float() torch.amax(weight).float()
out, _, _ = gemm( out, _, _ = tex.gemm(
weight, weight,
ln_out_total, ln_out_total,
activation_dtype, activation_dtype,
...@@ -289,6 +246,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -289,6 +246,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear # Row Parallel Linear
if parallel_mode == "row" and sequence_parallel: if parallel_mode == "row" and sequence_parallel:
...@@ -379,7 +337,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -379,7 +337,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_ = fp8_gemm( _ = tex.fp8_gemm(
weight_t_fp8, weight_t_fp8,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_WEIGHT, tex.FP8FwdTensors.GEMM1_WEIGHT,
...@@ -397,7 +355,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -397,7 +355,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
else: else:
# DGRAD: Evaluated unconditionally to feed into Linear backward # DGRAD: Evaluated unconditionally to feed into Linear backward
_, _, _ = gemm( _, _, _ = tex.gemm(
weight, weight,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
...@@ -427,7 +385,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -427,7 +385,7 @@ class _LayerNormLinear(torch.autograd.Function):
# WGRAD # WGRAD
if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
wgrad = fp8_gemm( wgrad = tex.fp8_gemm(
ln_out_total_t, ln_out_total_t,
fwd_scale_inverses, fwd_scale_inverses,
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
...@@ -446,14 +404,14 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -446,14 +404,14 @@ class _LayerNormLinear(torch.autograd.Function):
ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
) )
else: else:
ln_out_total_c = cast_from_fp8( ln_out_total_c = tex.cast_from_fp8(
ln_out_total, ln_out_total,
ctx.fp8_meta["scaling_fwd"], ctx.fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
TE_DType[ctx.activation_dtype], TE_DType[ctx.activation_dtype],
) )
wgrad, _, _ = gemm( wgrad, _, _ = tex.gemm(
ln_out_total_c, ln_out_total_c,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
...@@ -468,7 +426,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -468,7 +426,7 @@ class _LayerNormLinear(torch.autograd.Function):
) )
else: else:
# WGRAD # WGRAD
wgrad, grad_bias, _ = gemm( wgrad, grad_bias, _ = tex.gemm(
ln_out_total, ln_out_total,
grad_output, grad_output,
ctx.activation_dtype, ctx.activation_dtype,
...@@ -496,10 +454,18 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -496,10 +454,18 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.return_layernorm_output: if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd( dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
if not ctx.use_bias: if not ctx.use_bias:
grad_bias = None grad_bias = None
...@@ -533,6 +499,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -533,6 +499,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -555,6 +522,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -555,6 +522,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias. if set to `False`, the layer will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
init_method : Callable, default = `None` init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`. used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
...@@ -624,6 +593,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -624,6 +593,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
get_rng_state_tracker: Optional[Callable] = None, get_rng_state_tracker: Optional[Callable] = None,
init_method: Optional[Callable] = None, init_method: Optional[Callable] = None,
bias: bool = True, bias: bool = True,
normalization: str = 'LayerNorm',
return_bias: bool = False, return_bias: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None, parallel_mode: Optional[str] = None,
...@@ -649,9 +619,11 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -649,9 +619,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.use_bias = bias self.use_bias = bias
self.return_bias = return_bias self.return_bias = return_bias
self.apply_bias = bias and not return_bias self.apply_bias = self.use_bias and not return_bias
self.return_layernorm_output = return_layernorm_output self.return_layernorm_output = return_layernorm_output
self.parameters_split = parameters_split self.parameters_split = parameters_split
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
...@@ -696,6 +668,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -696,6 +668,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
dtype=params_dtype, dtype=params_dtype,
) )
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter( self.layer_norm_bias = Parameter(
torch.empty( torch.empty(
in_features, in_features,
...@@ -703,8 +677,9 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -703,8 +677,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
dtype=params_dtype, dtype=params_dtype,
) )
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
self.layer_norm_bias = None
self.reset_layer_norm_parameters() self.reset_layer_norm_parameters()
self.weight_tensor = torch.empty( self.weight_tensor = torch.empty(
...@@ -796,6 +771,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -796,6 +771,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
init.ones_(self.layer_norm_weight) init.ones_(self.layer_norm_weight)
else: else:
init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_weight)
if self.layer_norm_bias is not None:
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad( def get_fp8_weights_scratchpad(
...@@ -915,6 +891,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -915,6 +891,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_split_ag, self.ub_split_ag,
self.normalization,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
...@@ -46,6 +46,8 @@ from .. import cpp_extensions as tex ...@@ -46,6 +46,8 @@ from .. import cpp_extensions as tex
from ..constants import dist_group_type, TE_DType from ..constants import dist_group_type, TE_DType
from ..jit import no_torch_dynamo from ..jit import no_torch_dynamo
from ._common import _apply_normalization
__all__ = ["LayerNormMLP"] __all__ = ["LayerNormMLP"]
...@@ -107,6 +109,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -107,6 +109,7 @@ class _LayerNormMLP(torch.autograd.Function):
ub_split_rs: bool, ub_split_rs: bool,
ub_split_ag: bool, ub_split_ag: bool,
activation: str, activation: str,
normalization: str,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# Make sure input dimensions are compatible # Make sure input dimensions are compatible
in_features = ln_weight.numel() in_features = ln_weight.numel()
...@@ -124,6 +127,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -124,6 +127,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Cast for native AMP # Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype) inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype) ln_bias = cast_if_needed(ln_bias, activation_dtype)
if ub_split_ag: if ub_split_ag:
...@@ -133,70 +137,39 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -133,70 +137,39 @@ class _LayerNormMLP(torch.autograd.Function):
if ub_split_ag: if ub_split_ag:
ub_obj_lnout = get_ub("fc1_fprop") ub_obj_lnout = get_ub("fc1_fprop")
ln_out = ub_obj_lnout.get_ubuf_output(0) ln_out = ub_obj_lnout.get_ubuf_output(0)
else:
ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype
ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype)
if ub_split_rs: if ub_split_rs:
tp_world_size = get_distributed_world_size(tp_group) tp_world_size = get_distributed_world_size(tp_group)
if tp_world_size == 1: if tp_world_size == 1:
ub_split_rs = False ub_split_rs = False
# If residual connection is after LN, we need `ln_out`
# tensor in higher precision, this comes at the cost
# of an extra fp8 cast.
if fp8:
fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
if not return_layernorm_output:
if is_grad_enabled: ln_out, mu, rsigma = _apply_normalization(inputmat,
if not ub_split_ag: ln_out,
ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
_, mu, rsigma = tex.layernorm_fwd_fp8(
inputmat,
ln_weight, ln_weight,
ln_bias, ln_bias,
eps, eps,
fp8_meta["scaling_fwd"], fp8 and not return_layernorm_output,
tex.FP8FwdTensors.GEMM1_INPUT, fp8_meta,
fp8_dtype_forward, normalization,
fwd_ln_sm_margin, fwd_ln_sm_margin,
zero_centered_gamma, zero_centered_gamma,
ln_out = ln_out, is_grad_enabled)
) # If residual connection is after LN, we need `ln_out`
else: # tensor in higher precision, this comes at the cost
ln_out = tex.layernorm_fwd_fp8_inf( # of an extra fp8 cast.
inputmat, if return_layernorm_output:
ln_weight, ln_out_return = ln_out
ln_bias, if fp8:
eps,
fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward,
zero_centered_gamma,
)
else:
ln_out_return, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
ln_out = tex.cast_to_fp8( ln_out = tex.cast_to_fp8(
ln_out_return, ln_out,
fp8_meta["scaling_fwd"], fp8_meta["scaling_fwd"],
tex.FP8FwdTensors.GEMM1_INPUT, tex.FP8FwdTensors.GEMM1_INPUT,
fp8_dtype_forward, fp8_dtype_forward,
) )
else:
if is_grad_enabled:
if ub_split_ag:
_, mu, rsigma = tex.layernorm_fwd_noalloc(
inputmat, ln_weight, ln_bias, ln_out, eps,
fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd(
inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
)
else:
ln_out, mu, rsigma = tex.layernorm_fwd_inf(
inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
), None, None
ln_out_return = ln_out
# Column Parallel Linear # Column Parallel Linear
if ub_split_ag: if ub_split_ag:
ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out_total = ub_obj_lnout.get_ubuf_output(1)
...@@ -422,6 +395,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -422,6 +395,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_split_ag = ub_split_ag ctx.ub_split_ag = ub_split_ag
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.normalization = normalization
# Row Parallel Linear # Row Parallel Linear
if ub_split_rs: if ub_split_rs:
...@@ -804,10 +778,17 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -804,10 +778,17 @@ class _LayerNormMLP(torch.autograd.Function):
if ctx.return_layernorm_output: if ctx.return_layernorm_output:
d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out) d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
if ctx.normalization == "LayerNorm":
dxmat, dgamma, dbeta = tex.layernorm_bwd( dxmat, dgamma, dbeta = tex.layernorm_bwd(
d_ln_out, inputmat, mu, rsigma, ln_weight, d_ln_out, inputmat, mu, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
) )
elif ctx.normalization == "RMSNorm":
dxmat, dgamma = tex.rmsnorm_bwd(
d_ln_out, inputmat, rsigma, ln_weight,
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
return ( return (
dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None, dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
...@@ -846,6 +827,7 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -846,6 +827,7 @@ class _LayerNormMLP(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
) )
...@@ -864,6 +846,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -864,6 +846,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
a value added to the denominator of layer normalization for numerical stability. a value added to the denominator of layer normalization for numerical stability.
bias : bool, default = `True` bias : bool, default = `True`
if set to `False`, the FC1 and FC2 layers will not learn an additive bias. if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
activation : str, default = 'gelu' activation : str, default = 'gelu'
activation function used. activation function used.
Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
...@@ -942,6 +926,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -942,6 +926,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
tp_size: int = 1, tp_size: int = 1,
init_method: Optional[Callable] = None, init_method: Optional[Callable] = None,
bias: bool = True, bias: bool = True,
normalization: str = 'LayerNorm',
activation : str = "gelu", activation : str = "gelu",
output_layer_init_method: Optional[Callable] = None, output_layer_init_method: Optional[Callable] = None,
fuse_wgrad_accumulation: bool = False, fuse_wgrad_accumulation: bool = False,
...@@ -960,6 +945,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -960,6 +945,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.normalization = normalization
assert normalization in ['LayerNorm', 'RMSNorm'], "Unsupported normalization type!"
self.use_bias = bias self.use_bias = bias
self.activation = activation self.activation = activation
self.return_bias = return_bias self.return_bias = return_bias
...@@ -1005,6 +992,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1005,6 +992,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
dtype=params_dtype, dtype=params_dtype,
) )
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
if self.normalization != "RMSNorm":
self.layer_norm_bias = Parameter( self.layer_norm_bias = Parameter(
torch.empty( torch.empty(
hidden_size, hidden_size,
...@@ -1012,8 +1001,9 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1012,8 +1001,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
dtype=params_dtype, dtype=params_dtype,
) )
) )
setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
else:
self.layer_norm_bias = None
self.reset_layer_norm_parameters() self.reset_layer_norm_parameters()
if self.activation in ['reglu', 'geglu', 'swiglu']: if self.activation in ['reglu', 'geglu', 'swiglu']:
...@@ -1114,6 +1104,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1114,6 +1104,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
init.ones_(self.layer_norm_weight) init.ones_(self.layer_norm_weight)
else: else:
init.zeros_(self.layer_norm_weight) init.zeros_(self.layer_norm_weight)
if self.layer_norm_bias is not None:
init.zeros_(self.layer_norm_bias) init.zeros_(self.layer_norm_bias)
def get_fp8_weights_scratchpad( def get_fp8_weights_scratchpad(
...@@ -1217,6 +1208,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1217,6 +1208,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.ub_split_rs, self.ub_split_rs,
self.ub_split_ag, self.ub_split_ag,
self.activation, self.activation,
self.normalization,
) )
out = fwd_fn(*args) out = fwd_fn(*args)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""RMSNorm API"""
import os
from typing import Union, Tuple, Optional
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from .. import cpp_extensions as tex
from ..jit import no_torch_dynamo
__all__ = ["RMSNorm"]
class _RMSNorm(torch.autograd.Function):
    """Functional RMSNorm backed by TransformerEngine CUDA kernels.

    Forward flattens the input to 2D, runs the fused RMSNorm kernel, and
    (when autograd is active) stashes what the fused backward kernel needs.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        rmsnorm_weight: torch.Tensor,
        eps: float,
        fwd_rmsnorm_sm_margin: int,
        bwd_rmsnorm_sm_margin: int,
        zero_centered_gamma: bool,
        is_grad_enabled: bool,
    ) -> torch.Tensor:
        # Normalization happens over the last dimension, which must match
        # the size of the learnable gamma parameter.
        hidden_size = rmsnorm_weight.numel()
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert inp.shape[-1] == hidden_size, "RMSNorm not possible"
        flat_inp = inp.view((-1, hidden_size))

        if not is_grad_enabled:
            # Inference-only kernel: does not produce rsigma, so nothing
            # is saved for backward.
            out = tex.rmsnorm_fwd_inf(flat_inp, rmsnorm_weight,
                                      eps,
                                      zero_centered_gamma)
        else:
            out, rsigma = tex.rmsnorm_fwd(flat_inp, rmsnorm_weight,
                                          eps, fwd_rmsnorm_sm_margin,
                                          zero_centered_gamma)
            ctx.save_for_backward(flat_inp, rmsnorm_weight, rsigma)
            ctx.inp_shape = inp.shape
            ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin
            ctx.zero_centered_gamma = zero_centered_gamma
        return out.view_as(inp)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        flat_inp, rmsnorm_weight, rsigma = ctx.saved_tensors
        # The fused kernel requires a contiguous 2D gradient matching the
        # flattened forward input.
        d_out = grad_output.contiguous().view(flat_inp.shape)
        dxmat, dgamma = tex.rmsnorm_bwd(
            d_out, flat_inp, rsigma, rmsnorm_weight,
            ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma
        )
        # One slot per forward argument; only the input and gamma are
        # differentiable.
        return (
            dxmat.view(ctx.inp_shape),
            dgamma,
            None,
            None,
            None,
            None,
            None,
        )
class RMSNorm(torch.nn.Module):
    r"""
    Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in
    the paper `Root Mean Square Layer Normalization <https://arxiv.org/abs/1910.07467>`__

    .. math::
        y = \frac{x}{RMS(x) + \varepsilon} * \gamma

    where

    .. math::
        RMS(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2}

    :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size`

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    zero_centered_gamma : bool, default = 'False'
                         if set to 'True', gamma parameter in RMSNorm is initialized to 0 and
                         the RMSNorm formula changes to

                         .. math::
                            y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma)
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        zero_centered_gamma: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.zero_centered_gamma = zero_centered_gamma
        dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        # Gamma lives on the current CUDA device; the fused kernels require it.
        weight_data = torch.empty(
            hidden_size,
            device=torch.cuda.current_device(),
            dtype=dtype,
        )
        self.weight = Parameter(weight_data)
        setattr(self.weight, "sequence_parallel", sequence_parallel)
        self.reset_rms_norm_parameters()

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward RMSNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with RMSNorm.
        self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_rms_norm_parameters(self) -> None:
        """Init RMSNorm params: ones normally, zeros for the (1 + gamma) form."""
        if self.zero_centered_gamma:
            init.zeros_(self.weight)
        else:
            init.ones_(self.weight)

    @no_torch_dynamo
    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """RMSNorm FWD"""
        grad_enabled = torch.is_grad_enabled()
        call_args = (
            inp,
            self.weight,
            self.eps,
            self.fwd_rmsnorm_sm_margin,
            self.bwd_rmsnorm_sm_margin,
            self.zero_centered_gamma,
            grad_enabled,
        )
        if grad_enabled:
            return _RMSNorm.apply(*call_args)
        # Outside autograd, call forward() directly; the ctx slot is unused.
        return _RMSNorm.forward(None, *call_args)
...@@ -283,6 +283,20 @@ def onnx_te_gemm( ...@@ -283,6 +283,20 @@ def onnx_te_gemm(
return output return output
def _ones_like(g, inp, dtype):
    """Emit ONNX ops producing a tensor of ones with the same shape as `inp`
    and the requested `dtype`."""
    inp_shape = g.op("Shape", inp)
    # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16.
    # For BF16, build the ones tensor in FP32 and cast the result.
    if dtype == torch.bfloat16:
        ones = g.op("ConstantOfShape", inp_shape,
                    value_t=torch.tensor([1], dtype=torch.float32))
        return g.op("Cast", ones, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
    return g.op("ConstantOfShape", inp_shape,
                value_t=torch.tensor([1], dtype=dtype))
@symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b") @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax, def onnx_layernorm_fwd_fp8(g, inputs, weight, bias, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma): scale_inv, fp8_tensor, otype, zero_centered_gamma):
...@@ -305,19 +319,6 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ...@@ -305,19 +319,6 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
"""ONNX graph for layernorm_fwd""" """ONNX graph for layernorm_fwd"""
# pylint: disable=unused-argument # pylint: disable=unused-argument
def ones_like(inp, dtype):
"""Returns a tensor filled with the scalar value 1, with the same size as input and
with dtype data-type"""
shape = g.op("Shape", inp)
# WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR
# create a ConstantOfShape with type FP32 and then add a Cast to BF16.
is_bf16 = dtype == torch.bfloat16
one = g.op("ConstantOfShape", shape, value_t=torch.tensor([1],
dtype=torch.float32 if is_bf16 else dtype))
if is_bf16:
one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16)
return one
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None: if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
...@@ -328,7 +329,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ...@@ -328,7 +329,7 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
if zero_centered_gamma: if zero_centered_gamma:
inputs_dtype = inputs.type().dtype() inputs_dtype = inputs.type().dtype()
one = ones_like(weight, inputs_dtype) one = _ones_like(g, weight, inputs_dtype)
weight = g.op("Add", weight, one) weight = g.op("Add", weight, one)
axis = -len(normalized_shape) axis = -len(normalized_shape)
...@@ -344,6 +345,57 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma): ...@@ -344,6 +345,57 @@ def onnx_layernorm_fwd(g, inputs, weight, bias, eps, zero_centered_gamma):
) )
return ln return ln
@symbolic_helper.parse_args("v", "v", "f", "v", "v", "fs", "i", "i", "b")
def onnx_rmsnorm_fwd_fp8(g, inputs, weight, eps, scale, amax,
scale_inv, fp8_tensor, otype, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd_fp8"""
# pylint: disable=unused-argument
inp_dtype = get_TensorProtoDataType(inputs)
if inp_dtype != get_TensorProtoDataType(weight):
weight = g.op("Cast", weight, to_i=inp_dtype)
ln = onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma)
fp8_ln = quantize(g, ln, scale_inv, fp8_tensor)
return fp8_ln
@symbolic_helper.parse_args("v", "v", "f", "b")
def onnx_rmsnorm_fwd(g, inputs, weight, eps, zero_centered_gamma):
"""ONNX graph for rmsnorm_fwd"""
# pylint: disable=unused-argument
normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs)
if normalized_shape is None:
ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs)
assert ndim is not None
normalized_shape = list(range(0, ndim))
# Normalization axis = 0, so normalized_shape uses all dims except dim = 0
normalized_shape = normalized_shape[1:]
if zero_centered_gamma:
inputs_dtype = inputs.type().dtype()
one = _ones_like(g, weight, inputs_dtype)
weight = g.op("Add", weight, one)
axis = -len(normalized_shape)
inputs_float = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT)
norm = g.op("ReduceL2", inputs_float, axes_i=[axis])
shape = g.op("Shape", inputs_float, start_i=-1)
shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT)
n_reciprocal = g.op("Reciprocal", shape_f)
sqrt_n_reciprocal = g.op("Sqrt", n_reciprocal)
rms = g.op("Mul", norm, sqrt_n_reciprocal)
eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32))
rms_eps = g.op("Add", rms, eps_tensor)
normalized_input = g.op("Div", inputs_float, rms_eps)
result = g.op("Mul", weight, normalized_input)
result = g.op("Cast", result, to_i=get_TensorProtoDataType(inputs))
return result
register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER) register_custom_op_symbolic('tex_ts::cast_to_fp8_ts', onnx_cast_to_fp8, VER)
register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER) register_custom_op_symbolic('tex_ts::cast_from_fp8_ts', onnx_cast_from_fp8, VER)
...@@ -355,3 +407,5 @@ register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER) ...@@ -355,3 +407,5 @@ register_custom_op_symbolic('tex_ts::swiglu_ts', onnx_fp8_swiglu, VER)
register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER) register_custom_op_symbolic('tex_ts::te_gemm_ts', onnx_te_gemm, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_fp8_inf_ts', onnx_layernorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER) register_custom_op_symbolic('tex_ts::layernorm_fwd_inf_ts', onnx_layernorm_fwd, VER)
register_custom_op_symbolic('tex_ts::rmsnorm_fwd_fp8_inf_ts', onnx_rmsnorm_fwd_fp8, VER)
register_custom_op_symbolic('tex_ts::rmsnorm_fwd_inf_ts', onnx_rmsnorm_fwd, VER)
...@@ -11,7 +11,7 @@ from typing import Any, Callable, Optional, Tuple, Union ...@@ -11,7 +11,7 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_extensions as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import MultiHeadAttention from transformer_engine.pytorch.attention import MultiHeadAttention
from transformer_engine.pytorch.jit import ( from transformer_engine.pytorch.jit import (
set_jit_fusion_options, set_jit_fusion_options,
...@@ -128,6 +128,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -128,6 +128,8 @@ class TransformerLayer(torch.nn.Module):
.. math:: .. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
(1 + \gamma) + \beta (1 + \gamma) + \beta
normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm'
type of normalization applied.
qkv_weight_interleaved : bool, default = `True` qkv_weight_interleaved : bool, default = `True`
if set to `False`, the QKV weight is interpreted as a concatenation of if set to `False`, the QKV weight is interpreted as a concatenation of
query, key, and value weights along the `0th` dimension. The default query, key, and value weights along the `0th` dimension. The default
...@@ -220,7 +222,8 @@ class TransformerLayer(torch.nn.Module): ...@@ -220,7 +222,8 @@ class TransformerLayer(torch.nn.Module):
qkv_weight_interleaved: bool = True, qkv_weight_interleaved: bool = True,
ub_tp_comm_overlap: bool = False, ub_tp_comm_overlap: bool = False,
bias: bool = True, bias: bool = True,
activation: str = 'gelu' activation: str = 'gelu',
normalization: str = "LayerNorm",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -312,6 +315,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -312,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
input_layernorm=not output_layernorm, input_layernorm=not output_layernorm,
attention_type="self", attention_type="self",
bias=bias, bias=bias,
normalization=normalization,
) )
if layer_type == "decoder": if layer_type == "decoder":
...@@ -322,6 +326,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -322,6 +326,7 @@ class TransformerLayer(torch.nn.Module):
input_layernorm=True, input_layernorm=True,
attention_type="cross", attention_type="cross",
bias=bias, bias=bias,
normalization=normalization,
) )
# LayerNorm -> activation(Linear + Bias) -> Linear # LayerNorm -> activation(Linear + Bias) -> Linear
...@@ -353,6 +358,7 @@ class TransformerLayer(torch.nn.Module): ...@@ -353,6 +358,7 @@ class TransformerLayer(torch.nn.Module):
ub_split_rs=ub_split_rs, ub_split_rs=ub_split_rs,
ub_split_ag=ub_split_ag, ub_split_ag=ub_split_ag,
activation=activation, activation=activation,
normalization=normalization,
) )
self.hidden_dropout = hidden_dropout self.hidden_dropout = hidden_dropout
...@@ -376,8 +382,12 @@ class TransformerLayer(torch.nn.Module): ...@@ -376,8 +382,12 @@ class TransformerLayer(torch.nn.Module):
hidden_size, seq_length, micro_batch_size hidden_size, seq_length, micro_batch_size
) )
norm_module = {
"LayerNorm": LayerNorm,
"RMSNorm": RMSNorm,
}
if self.output_layernorm: if self.output_layernorm:
self.layernorm = LayerNorm( self.layernorm = norm_module[normalization](
hidden_size, hidden_size,
eps=layernorm_epsilon, eps=layernorm_epsilon,
sequence_parallel=self.sequence_parallel, sequence_parallel=self.sequence_parallel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment