Unverified Commit c6a4a4e0 authored by Kirthi Shankar Sivamani, committed by GitHub

PyTorch refactor (#201)



* Initial refactor

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* refactor attention out of transformer.py

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix ONNX export

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* linting

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f7608d89
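The hunks below capture the two import-path changes made by this refactor: attention classes move out of transformer.py into transformer_engine.pytorch.attention, and base utilities such as get_workspace and TransformerEngineBaseModule now live under module.base. A minimal before/after sketch of the updated imports (paths taken from the diffs below; the keyword arguments are illustrative values copied from the tests):

# Old locations (pre-refactor):
#   from transformer_engine.pytorch.module import get_workspace, TransformerEngineBaseModule
#   model = te.transformer.DotProductAttention(...)

# New locations (post-refactor), as exercised by the updated tests below:
import transformer_engine.pytorch as te
from transformer_engine.pytorch.module.base import get_workspace, TransformerEngineBaseModule

attn = te.attention.DotProductAttention(
    num_attention_heads=16,   # illustrative sizes, matching the test's keyword args
    kv_channels=64,
    attention_dropout=0.1,
)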
@@ -34,7 +34,7 @@ import transformer_engine.pytorch as te
 from transformer_engine.common import recipe
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, fp8_gelu, cast_to_fp8, cast_from_fp8
-from transformer_engine.pytorch.module import get_workspace
+from transformer_engine.pytorch.module.base import get_workspace
 import transformer_engine.pytorch.cpp_extensions as texcpp
 import transformer_engine.pytorch.softmax as softmax_defs
 from transformer_engine.pytorch.utils import get_default_init_method
@@ -882,7 +882,7 @@ def test_export_core_attention(
     if attn_mask_type is None:
         attn_mask_type = 'causal'
     inp = (query_layer, key_layer, value_layer)
-    model = te.transformer.DotProductAttention(
+    model = te.attention.DotProductAttention(
         num_attention_heads=num_attention_heads,
         kv_channels=kv_channels,
         attention_dropout=0.5,
@@ -972,7 +972,7 @@ def test_export_multihead_attention(
     input_ln_str = "_input-ln" if input_layernorm else ""
     fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"
-    model = te.transformer.MultiHeadAttention(
+    model = te.attention.MultiHeadAttention(
         *attention_args,
         attn_mask_type=attn_mask_type,
         params_dtype=precision,
......
@@ -17,8 +17,8 @@ import torch
 import transformer_engine.pytorch as te
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8
-from transformer_engine.pytorch.module import get_workspace
-from transformer_engine.pytorch.module import TransformerEngineBaseModule
+from transformer_engine.pytorch.module.base import get_workspace
+from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
 def init_meta(size: int=1):
......
@@ -7,7 +7,7 @@ from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
 from .module import LayerNorm
-from .transformer import DotProductAttention
+from .attention import DotProductAttention
 from .transformer import TransformerLayer
 from .fp8 import fp8_autocast
 from .export import onnx_export
......
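Because transformer_engine/pytorch/__init__.py still re-exports DotProductAttention (now sourced from .attention), code importing it from the package root should be unaffected; only the fully qualified path changes. A small sketch of the two spellings under that assumption:

# Unaffected: the package-level re-export keeps working
from transformer_engine.pytorch import DotProductAttention

# Changed by this commit: the fully qualified path moves from .transformer to .attention
from transformer_engine.pytorch.attention import DotProductAttention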
This diff is collapsed.
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Module level PyTorch APIs"""
from .layernorm_linear import LayerNormLinear
from .linear import Linear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
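This new module/__init__.py keeps the old "from transformer_engine.pytorch.module import ..." spellings working by re-exporting the layer classes from their split-out files; only internals such as get_workspace and TransformerEngineBaseModule need the module.base path (as in the test updates above). A brief sketch, assuming nothing beyond the re-exports listed here:

# Still valid after the split, via the re-exports above
from transformer_engine.pytorch.module import LayerNorm, Linear

# Internals moved into module/base.py
from transformer_engine.pytorch.module.base import get_workspace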
This diff is collapsed.
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LayerNorm API"""
import os
from typing import Union, Tuple, Any, Mapping
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import transformer_engine_extensions as tex
__all__ = ["LayerNorm"]
class _LayerNorm(torch.autograd.Function):
    """functional LayerNorm"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        eps: float,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
    ) -> torch.Tensor:
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert inp.shape[-1] == in_features, "LayerNorm not possible"
        inputmat = inp.view((-1, in_features))

        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
                                               ln_bias, eps, fwd_ln_sm_margin,
                                               zero_centered_gamma)
        ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
        ctx.inp_shape = inp.shape
        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
        ctx.zero_centered_gamma = zero_centered_gamma
        return ln_out.view_as(inp)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
        grad_output = grad_output.contiguous()
        d_ln_out = grad_output.view(inputmat.shape)
        dxmat, dgamma, dbeta = tex.layernorm_bwd(
            d_ln_out, inputmat, mu, rsigma, ln_weight,
            ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
        )
        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
class LayerNorm(torch.nn.Module):
    r"""
    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size :attr:`hidden_size`

    Parameters
    ----------
    hidden_size : int
        size of each input sample.
    eps : float, default = 1e-5
        a value added to the denominator of layer normalization for numerical stability.
    sequence_parallel : bool, default = `False`
        if set to `True`, uses sequence parallelism.
    params_dtype : torch.dtype, default = `torch.float32`
        it controls the type used to allocate the initial parameters. Useful when
        the model is trained with lower precision and the original FP32 parameters
        would not fit in GPU memory.
    zero_centered_gamma : bool, default = `False`
        if set to `True`, the gamma parameter in LayerNorm is initialized to 0 and
        the LayerNorm formula changes to

        .. math::
            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
            (1 + \gamma) + \beta
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.float32,
        zero_centered_gamma: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.zero_centered_gamma = zero_centered_gamma
        self.weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.weight, "sequence_parallel", sequence_parallel)
        setattr(self.bias, "sequence_parallel", sequence_parallel)
        self.reset_layer_norm_parameters()

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
    ) -> None:
        """Override PyTorch loader to maintain backward compatibility
        with previous version of LayerNorm parameter names.
        """
        if "layer_norm_weight" in state_dict:
            state_dict["weight"] = state_dict["layer_norm_weight"]
            del state_dict["layer_norm_weight"]
        if "layer_norm_bias" in state_dict:
            state_dict["bias"] = state_dict["layer_norm_bias"]
            del state_dict["layer_norm_bias"]
        super().load_state_dict(state_dict, strict)

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        if not self.zero_centered_gamma:
            init.ones_(self.weight)
        else:
            init.zeros_(self.weight)
        init.zeros_(self.bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """LayerNorm FWD"""
        # Maintain backward compatibility.
        if hasattr(self, "layer_norm_weight"):
            setattr(self, "weight", self.layer_norm_weight)
        if hasattr(self, "layer_norm_bias"):
            setattr(self, "bias", self.layer_norm_bias)

        return _LayerNorm.apply(
            inp,
            self.weight,
            self.bias,
            self.eps,
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma
        )
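For reference, a minimal usage sketch of the LayerNorm module defined above; the tensor shapes and SM-margin value are illustrative assumptions, and a CUDA tensor is required by the assertion in _LayerNorm.forward:

import os
import torch
import transformer_engine.pytorch as te

# Optionally reserve SMs for the forward LayerNorm kernel so other work
# (e.g. communication) can overlap; the envvar is read once at construction.
os.environ["NVTE_FWD_LAYERNORM_SM_MARGIN"] = "2"  # illustrative value

hidden_size = 1024  # illustrative size
ln = te.LayerNorm(hidden_size, eps=1e-5, zero_centered_gamma=True)

x = torch.randn(8, 128, hidden_size, device="cuda")
y = ln(x)  # same shape as x; gamma is stored zero-centered and applied as (1 + gamma)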
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.