Unverified commit 199fa834 authored by Masaki Kozuki, committed by GitHub

be more flexible (#1299)

parent 069ff336
@@ -3,6 +3,8 @@ from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence
 
 import torch
 from torch.autograd.variable import Variable
+from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+from apex.normalization.fused_layer_norm import FusedLayerNorm
 from apex.transformer import parallel_state
 from apex.transformer.enums import ModelType
 from apex.transformer.pipeline_parallel.utils import get_num_microbatches
@@ -119,18 +121,19 @@ def _calc_number_of_params(model: List[torch.nn.Module]) -> int:
 def _get_params_for_weight_decay_optimization(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
+    *,
+    no_weight_decay_modules=(FastLayerNorm, FusedLayerNorm),
 ) -> Dict[str, torch.nn.Parameter]:
     """Divide params into with-weight-decay and without-weight-decay groups.
 
     Layernorms and biases will have no weight decay but the rest will.
     """
     modules = listify_model(model)
-    from apex.normalization.fused_layer_norm import FusedLayerNorm  # NOQA
     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
     for module in modules:
         for module_ in module.modules():
-            if isinstance(module_, FusedLayerNorm):
+            if isinstance(module_, no_weight_decay_modules):
                 no_weight_decay_params['params'].extend(
                     [p for p in list(module_._parameters.values())
                      if p is not None])
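Below is a minimal usage sketch of the new keyword-only argument; it is not part of the commit. The import path of the helper is an assumption inferred from the hunk header (the same file that defines _calc_number_of_params), and the sketch assumes the helper returns parameter groups that a torch optimizer accepts, since the return statement sits outside the hunk shown above.

import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm
# Assumed import path, inferred from the hunk header; adjust if the helper lives elsewhere.
from apex.transformer.pipeline_parallel.schedules.common import (
    _get_params_for_weight_decay_optimization,
)


class TinyBlock(torch.nn.Module):
    """A toy module mixing a layer norm (no weight decay) and a linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = FusedLayerNorm(16)        # matched by no_weight_decay_modules
        self.proj = torch.nn.Linear(16, 16)   # its weight keeps weight decay


model = TinyBlock()

# The new keyword-only argument lets callers choose which module classes are
# exempt from weight decay, instead of the previously hard-coded FusedLayerNorm.
param_groups = _get_params_for_weight_decay_optimization(
    model,
    no_weight_decay_modules=(FusedLayerNorm, torch.nn.LayerNorm),
)

# Assumption: the returned groups are directly consumable by torch optimizers
# (the return statement is not visible in the hunk above).
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=0.01)

Before this change the exemption was limited to FusedLayerNorm via a function-local import; the new default tuple also covers FastLayerNorm from apex.contrib, and callers can pass any tuple of module classes.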