Commit 28062e14 authored by mohammad

Moved the entire optimizer build into get_megatron_optimizer and tested.

parent fb218c9d
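With this change the optimizer is built entirely inside get_megatron_optimizer, which now receives just the unwrapped model: it constructs the weight-decay / no-weight-decay parameter groups itself (via the relocated get_params_for_weight_decay_optimization) and creates the base FusedAdam optimizer, so the training code no longer needs its own get_optimizer. A minimal plain-PyTorch sketch of the grouping rule being centralized (illustration only, not the Megatron code; split_param_groups is a hypothetical name, and torch.nn.LayerNorm stands in for the LayerNorm class Megatron selects via import_layernorm):

# Illustration only: sketch of the grouping rule that
# get_params_for_weight_decay_optimization applies -- LayerNorm parameters and
# all biases get weight_decay=0.0, everything else keeps the default decay.
# split_param_groups is a made-up helper name, not part of Megatron.
import torch

def split_param_groups(module):
    decay, no_decay = [], []
    for sub in module.modules():
        # _parameters holds only the submodule's direct parameters,
        # so each parameter is visited exactly once.
        for name, param in sub._parameters.items():
            if param is None:
                continue
            if isinstance(sub, torch.nn.LayerNorm) or name == 'bias':
                no_decay.append(param)
            else:
                decay.append(param)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]

# Usage: build the optimizer from the groups, as get_megatron_optimizer now does.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
optimizer = torch.optim.Adam(split_param_groups(model), lr=1e-3,
                             weight_decay=0.01, betas=(0.9, 0.999), eps=1e-8)
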
@@ -33,7 +33,6 @@ from .distributed import *
 from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
 from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
-from .utils import get_params_for_weight_decay_optimization
 from .language_model import get_language_model
@@ -20,7 +20,6 @@ import math
 import torch
 from megatron import get_args
-from megatron.model import import_layernorm
 def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""
@@ -60,28 +59,3 @@ def openai_gelu(x):
 @torch.jit.script
 def erf_gelu(x):
     return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
-def get_params_for_weight_decay_optimization(module):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and biases will have no weight decay but the rest will.
-    """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
-    return weight_decay_params, no_weight_decay_params
@@ -21,16 +21,49 @@ from abc import abstractmethod
 import torch
 from apex.multi_tensor_apply import multi_tensor_applier
+from apex.optimizers import FusedAdam as Adam
 import amp_C
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
+from megatron.model import import_layernorm
-def get_megatron_optimizer(optimizer, model):
+def get_params_for_weight_decay_optimization(module):
+    """Divide params into with-weight-decay and without-weight-decay groups.
+    Layernorms and biases will have no weight decay but the rest will.
+    """
     args = get_args()
+    LayerNorm = import_layernorm(args.fp32_residual_connection)
+    weight_decay_params = {'params': []}
+    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    for module_ in module.modules():
+        if isinstance(module_, LayerNorm):
+            no_weight_decay_params['params'].extend(
+                [p for p in list(module_._parameters.values())
+                 if p is not None])
+        else:
+            weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n != 'bias'])
+            no_weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n == 'bias'])
+    return weight_decay_params, no_weight_decay_params
+
+
+def get_megatron_optimizer(model):
+    args = get_args()
+
+    # Base optimizer.
+    param_groups = get_params_for_weight_decay_optimization(model)
+    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
+                     betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
     if args.fp16:
         # Constant loss scale.
...
@@ -24,7 +24,6 @@ _TRAIN_START_TIME = time.time()
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from apex.optimizers import FusedAdam as Adam
 from megatron import get_args
 from megatron import get_timers
@@ -45,7 +44,6 @@ from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
@@ -184,6 +182,10 @@ def get_model(model_provider_func):
     # Build model on cpu.
     model = model_provider_func()
+    # Set tensor model parallel attributes if not set.
+    for param in model.parameters():
+        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
@@ -212,30 +214,6 @@ def get_model(model_provider_func):
                          'Exiting.'.format(args.DDP_impl))
-def get_optimizer(model):
-    """Set up the optimizer."""
-    args = get_args()
-    # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
-        model = model.module
-    param_groups = get_params_for_weight_decay_optimization(model)
-    # Add model parallel attribute if it is not set.
-    for param_group in param_groups:
-        for param in param_group['params']:
-            if not hasattr(param, 'tensor_model_parallel'):
-                param.tensor_model_parallel = False
-    # Use Adam.
-    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay,
-                     betas=(args.adam_beta1, args.adam_beta2), eps=args.adam_eps)
-    # Wrap into fp16 optimizer.
-    optimizer = get_megatron_optimizer(optimizer, model)
-    return optimizer
 def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()
@@ -284,7 +262,12 @@ def setup_model_and_optimizer(model_provider_func):
     args = get_args()
     model = get_model(model_provider_func)
-    optimizer = get_optimizer(model)
+    unwrapped_model = model
+    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
+        unwrapped_model = unwrapped_model.module
+    optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
     if args.load is not None:
...