Unverified Commit 0e9b2771 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix LayerNorm API param names (#42)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e456110b
@@ -6,7 +6,7 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
-from typing import Union, Optional, Callable, Tuple, Dict, List, Any
+from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
 from functools import partial
 import torch
@@ -2535,31 +2535,52 @@ class LayerNorm(torch.nn.Module):
     ) -> None:
         super().__init__()
         self.eps = eps
-        self.layer_norm_weight = Parameter(
+        self.weight = Parameter(
             torch.empty(
                 hidden_size,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype,
             )
         )
-        self.layer_norm_bias = Parameter(
+        self.bias = Parameter(
             torch.empty(
                 hidden_size,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype,
             )
         )
-        setattr(self.layer_norm_weight, "sequence_parallel", sequence_parallel)
-        setattr(self.layer_norm_bias, "sequence_parallel", sequence_parallel)
+        setattr(self.weight, "sequence_parallel", sequence_parallel)
+        setattr(self.bias, "sequence_parallel", sequence_parallel)
         self.reset_layer_norm_parameters()

+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+    ) -> None:
+        """Override PyTorch loader to maintain backward compatibility
+        with previous version of LayerNorm parameter names.
+        """
+        if "layer_norm_weight" in state_dict:
+            state_dict["weight"] = state_dict["layer_norm_weight"]
+            del state_dict["layer_norm_weight"]
+        if "layer_norm_bias" in state_dict:
+            state_dict["bias"] = state_dict["layer_norm_bias"]
+            del state_dict["layer_norm_bias"]
+        super().load_state_dict(state_dict, strict)
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
-        init.ones_(self.layer_norm_weight)
-        init.zeros_(self.layer_norm_bias)
+        init.ones_(self.weight)
+        init.zeros_(self.bias)

     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
-        return _LayerNorm.apply(
-            inp, self.layer_norm_weight, self.layer_norm_bias, self.eps
-        )
+        # Maintain backward compatibility.
+        if hasattr(self, "layer_norm_weight"):
+            setattr(self, "weight", self.layer_norm_weight)
+        if hasattr(self, "layer_norm_bias"):
+            setattr(self, "bias", self.layer_norm_bias)
+        return _LayerNorm.apply(inp, self.weight, self.bias, self.eps)
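
A minimal sketch (not part of this commit) of how the key remapping above keeps checkpoints saved with the old parameter names loadable after the rename. torch.nn.LayerNorm is used here only as a stand-in for the module in this diff; the layer_norm_weight/layer_norm_bias -> weight/bias remapping mirrors the load_state_dict override shown above.

    import torch

    # Checkpoint saved with the previous parameter names.
    old_state = {
        "layer_norm_weight": torch.ones(16),
        "layer_norm_bias": torch.zeros(16),
    }

    # Same key remapping as the load_state_dict override in the diff.
    state = dict(old_state)
    for old_key, new_key in (("layer_norm_weight", "weight"), ("layer_norm_bias", "bias")):
        if old_key in state:
            state[new_key] = state.pop(old_key)

    ln = torch.nn.LayerNorm(16)  # stand-in module whose parameters are named "weight"/"bias"
    ln.load_state_dict(state)    # loads cleanly, so old checkpoints keep working after the rename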