Unverified Commit 4506a687 authored by Masaki Kozuki, committed by GitHub

skip FastLayerNorm (#1305)

parent 199fa834
@@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
 import torch
 from torch.autograd.variable import Variable
-from apex.contrib.layer_norm.layer_norm import FastLayerNorm
 from apex.normalization.fused_layer_norm import FusedLayerNorm
 from apex.transformer import parallel_state
 from apex.transformer.enums import ModelType
@@ -122,7 +121,7 @@ def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
 def _get_params_for_weight_decay_optimization(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     *,
-    no_weight_decay_modules=(FastLayerNorm, FusedLayerNorm),
+    no_weight_decay_modules=(FusedLayerNorm,),
 ) -> Dict[str, torch.nn.Parameter]:
     """Divide params into with-weight-decay and without-weight-decay groups.
     ...
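The only behavioral change is the default of `no_weight_decay_modules`: `FastLayerNorm` is no longer imported or excluded by type, leaving `FusedLayerNorm` as the sole default. Below is a minimal sketch, not the apex implementation, of how such a helper typically partitions parameters; the function name `split_params_for_weight_decay`, the bias-exclusion rule, and the `weight_decay` values are illustrative assumptions, and the real function's return structure may differ.

```python
from typing import List, Union

import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm


def split_params_for_weight_decay(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    *,
    no_weight_decay_modules=(FusedLayerNorm,),
):
    # Hypothetical sketch: parameters owned by modules in `no_weight_decay_modules`
    # (and, commonly, bias parameters) are placed in a zero-weight-decay group.
    models = model if isinstance(model, list) else [model]
    decay, no_decay = [], []
    for m in models:
        for module in m.modules():
            for name, param in module.named_parameters(recurse=False):
                if isinstance(module, no_weight_decay_modules) or name == "bias":
                    no_decay.append(param)
                else:
                    decay.append(param)
    # Illustrative param groups; the weight_decay value would come from the optimizer config.
    return (
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    )
```

Such groups can then be passed directly to an optimizer, e.g. `torch.optim.AdamW(split_params_for_weight_decay(model))`, so that layer-norm weights are not decayed.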