"vscode:/vscode.git/clone" did not exist on "358ba459d2f05321a15555c3b57c606fe2597ec7"
Unverified Commit 0e9b2771 authored by Kirthi Shankar Sivamani, committed by GitHub

Fix LayerNorm API param names (#42)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e456110b
@@ -6,7 +6,7 @@
 import os
 import warnings
 from abc import ABC, abstractmethod
-from typing import Union, Optional, Callable, Tuple, Dict, List, Any
+from typing import Union, Optional, Callable, Tuple, Dict, List, Any, Mapping
 from functools import partial
 import torch
@@ -2535,31 +2535,52 @@ class LayerNorm(torch.nn.Module):
     ) -> None:
         super().__init__()
         self.eps = eps
-        self.layer_norm_weight = Parameter(
+        self.weight = Parameter(
             torch.empty(
                 hidden_size,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype,
             )
         )
-        self.layer_norm_bias = Parameter(
+        self.bias = Parameter(
             torch.empty(
                 hidden_size,
                 device=torch.cuda.current_device(),
                 dtype=params_dtype,
             )
         )
-        setattr(self.layer_norm_weight, "sequence_parallel", sequence_parallel)
-        setattr(self.layer_norm_bias, "sequence_parallel", sequence_parallel)
+        setattr(self.weight, "sequence_parallel", sequence_parallel)
+        setattr(self.bias, "sequence_parallel", sequence_parallel)
         self.reset_layer_norm_parameters()
 
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+    ) -> None:
+        """Override PyTorch loader to maintain backward compatibility
+        with previous version of LayerNorm parameter names.
+        """
+        if "layer_norm_weight" in state_dict:
+            state_dict["weight"] = state_dict["layer_norm_weight"]
+            del state_dict["layer_norm_weight"]
+        if "layer_norm_bias" in state_dict:
+            state_dict["bias"] = state_dict["layer_norm_bias"]
+            del state_dict["layer_norm_bias"]
+        super().load_state_dict(state_dict, strict)
+
     def reset_layer_norm_parameters(self) -> None:
         """Init LN params"""
-        init.ones_(self.layer_norm_weight)
-        init.zeros_(self.layer_norm_bias)
+        init.ones_(self.weight)
+        init.zeros_(self.bias)
 
     def forward(self, inp: torch.Tensor) -> torch.Tensor:
         """LayerNorm FWD"""
-        return _LayerNorm.apply(
-            inp, self.layer_norm_weight, self.layer_norm_bias, self.eps
-        )
+        # Maintain backward compatibility.
+        if hasattr(self, "layer_norm_weight"):
+            setattr(self, "weight", self.layer_norm_weight)
+        if hasattr(self, "layer_norm_bias"):
+            setattr(self, "bias", self.layer_norm_bias)
+        return _LayerNorm.apply(inp, self.weight, self.bias, self.eps)
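
For illustration only (not part of the commit): a minimal sketch of the key remapping that the overridden load_state_dict performs, so checkpoints saved under the old parameter names still load after the rename. The checkpoint dict and the remap_old_keys helper below are hypothetical stand-ins, not code from the repository.

import torch

# Hypothetical checkpoint saved by an older version of the module,
# using the pre-rename parameter names.
old_checkpoint = {
    "layer_norm_weight": torch.ones(16),
    "layer_norm_bias": torch.zeros(16),
}

def remap_old_keys(state_dict):
    # Same remapping the commit adds in load_state_dict: rewrite the old
    # keys to the new names before delegating to the stock PyTorch loader.
    for old, new in (("layer_norm_weight", "weight"), ("layer_norm_bias", "bias")):
        if old in state_dict:
            state_dict[new] = state_dict.pop(old)
    return state_dict

print(sorted(remap_old_keys(old_checkpoint)))  # ['bias', 'weight']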