Commit 329fe582 authored by Lawrence McAfee

working on Float16DistributedOptimizer

parent 7dc8c475
@@ -168,6 +168,14 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # >>>
    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    # <<<

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False
@@ -700,6 +708,10 @@ def _add_distributed_args(parser):
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
    # >>>
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')
    # <<<

    return parser
...
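Aside (not part of the commit): a small standalone sketch of how the new store_true flag parses and of the constraint the new asserts in parse_args encode, namely that the distributed optimizer requires local DDP with contiguous grad buffers. The two auxiliary flag definitions below are illustrative stand-ins, not Megatron's exact argument definitions.

import argparse

# Minimal, hypothetical argument setup mirroring the names used in the diff above.
parser = argparse.ArgumentParser()
parser.add_argument('--use-distributed-optimizer', action='store_true',
                    help='Use distributed optimizer.')
parser.add_argument('--DDP-impl', default='local', choices=['local', 'torch'])
parser.add_argument('--no-contiguous-buffers-in-local-ddp', action='store_false',
                    dest='use_contiguous_buffers_in_local_ddp')

args = parser.parse_args(['--use-distributed-optimizer'])

# The new validation path: distributed optimizer implies local DDP and
# contiguous grad buffers.
if args.use_distributed_optimizer:
    assert args.DDP_impl == 'local'
    assert args.use_contiguous_buffers_in_local_ddp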
@@ -19,9 +19,17 @@ from apex.optimizers import FusedSGD as SGD

from megatron import get_args
from megatron.model import LayerNorm

# >>>
from .distributed_fused_adam import DistributedFusedAdam
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
# >>>
from .optimizer import (
    Float16OptimizerWithFloat16Params,
    Float16DistributedOptimizer,
    FP32Optimizer,
)
# <<<


def get_param_groups(modules,
                     no_weight_decay_cond,
@@ -97,7 +105,11 @@ def get_megatron_optimizer(model,
    #     })
    # <<<

    # >>>
    if args.use_distributed_optimizer:
        optimizer = DistributedFusedAdam(param_groups)
    # <<<
    elif args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
@@ -141,13 +153,18 @@ def get_megatron_optimizer(model,
                                           hysteresis=args.hysteresis)

        # Megatron optimizer.
        # >>>
        opt_ty = Float16DistributedOptimizer \
            if args.use_distributed_optimizer \
            else Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.bf16,
                      grad_scaler)
        # <<<

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
...
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch

from megatron import mpu

# >>>
from lutil import pax, tp
# <<<


class DistributedFusedAdam(torch.optim.Optimizer):

    def __init__(self, params):

        super().__init__(params, defaults = {})

        self.initialized = False

        # self.params_32 = None
        # self.grads_32 = None
        # self.opt_m = None
        # self.opt_v = None

        # pax(0, {
        #     "param_groups" : self.param_groups,
        #     "param_groups / 0" : self.param_groups[0],
        #     "param_groups / 1" : self.param_groups[1],
        #     "param_groups / 0 / params" : self.param_groups[0]["params"],
        #     # "param_groups / params" : [ g["params"] for g in self.param_groups ],
        # })

    def initialize(self):

        # Lazily allocate the sharded optimizer state on first use; the raise
        # below is a temporary debug trap for a second call.
        if self.initialized:
            raise Exception("initialization worked.")
        self.initialized = True

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Partition the flattened parameter space across data-parallel ranks
        # via ceil division; the last rank's shard may be smaller.
        total_param_size = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"]
        )
        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
        shard_start_index = data_parallel_rank * shard_size
        shard_end_index = min(total_param_size, shard_start_index + shard_size)
        shard_size = shard_end_index - shard_start_index

        # Allocate fp32 shards for the main params, main grads, and the Adam
        # first/second moments.
        allocate_shard = lambda dtype : torch.empty(
            [shard_size],
            dtype = dtype,
            device = torch.cuda.current_device())
        self.main_param_shard = allocate_shard(torch.float)
        self.main_grad_shard = allocate_shard(torch.float)
        self.adam_m_shard = allocate_shard(torch.float)
        self.adam_v_shard = allocate_shard(torch.float)

        # pax(2, {
        #     "data_parallel_rank" : data_parallel_rank,
        #     "data_parallel_world_size" : data_parallel_world_size,
        #     "total_param_size" : total_param_size,
        #     "shard_size" : shard_size,
        #     "shard" : "%d [ %d, %d ]" % (
        #         shard_size,
        #         shard_start_index,
        #         shard_end_index,
        #     ),
        # })

    def step(self):

        self.initialize()

        raise Exception("what's next?")
# >>>
# eof
# <<<
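Aside (not part of the commit): a tiny standalone example of the sharding arithmetic used in initialize() above, applied to a toy parameter count of 10 split across 4 data-parallel ranks.

import math

def shard_range(total_param_size, rank, world_size):
    # Same arithmetic as initialize(): fixed-size shards via ceil division,
    # with the final shard clipped to the true parameter count.
    shard_size = int(math.ceil(total_param_size / world_size))
    start = rank * shard_size
    end = min(total_param_size, start + shard_size)
    return start, end

for rank in range(4):
    print(rank, shard_range(10, rank, 4))
# 0 (0, 3)
# 1 (3, 6)
# 2 (6, 9)
# 3 (9, 10)   <- the last rank's shard is smaller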
@@ -275,6 +275,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                        # <<<
                        # >>>
                        # debug()
                        # from lutil import pax, tp
                        # pax(0, {
                        #     "param" : tp(param),
                        #     "main_param" : tp(main_param),
                        # })
                        # <<<
                        fp32_from_float16_params_this_group.append(main_param)

                        # Reset existing state dict key to the new main param.
@@ -354,6 +360,84 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
        return self.grad_scaler.scale

    # >>>
    def reduce_gradients(self):

        # NOTE (work in progress): this body was moved from train_step(),
        # which now calls optimizer.reduce_gradients(); `args`, `timers`,
        # `model`, and `get_grad_shard` still need to be wired up.

        # >>>
        # if not args.use_distributed_optimizer:

        # All-reduce if needed.
        # >>>
        # if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
        if args.DDP_impl == 'local':
        # <<<
            timers('backward-params-all-reduce').start()
            for model_module in model:
                # >>>
                # from lutil import pax, tp
                # pax(0, {
                #     "model" : model,
                #     "model_module" : model_module,
                # })
                # <<<
                # >>>
                # e.g., grad_shard = optimizer.get_grad_shard()
                # <<<
                model_module.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync.
        # This should only run for models that support pipelined model parallelism
        # (BERT and GPT-2).
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = model[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # >>>
                # if args.DDP_impl == 'local':
                #     grad = word_embeddings_weight.main_grad
                # else:
                #     grad = word_embeddings_weight.grad
                # torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                # +++
                grad_shard = self.get_grad_shard(word_embeddings_weight)
                torch.distributed.all_reduce(grad_shard,
                                             group=mpu.get_embedding_group())
                # <<<

        # All-reduce position_embeddings grad across first (encoder) and split (decoder)
        # stages to ensure that position embeddings parameters stay in sync.
        # This should only run for T5 models with pipeline parallelism
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            # >>>
            # grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            # torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
            # +++
            grad_shard = self.get_grad_shard(
                unwrapped_model.language_model.embedding.position_embeddings.weight)
            torch.distributed.all_reduce(grad_shard,
                                         group=mpu.get_position_embedding_group())
            # <<<
        timers('backward-embedding-all-reduce').stop()
    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
@@ -542,6 +626,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                current_param.data.copy_(saved_param.data)


# >>>
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):

    def step(self):
        raise Exception("hi.")
# <<<


class FP32Optimizer(MegatronOptimizer):
...
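Aside (not part of the commit): Float16DistributedOptimizer.step() above is still a stub, so for orientation the sketch below shows the general pattern a sharded, ZeRO-style optimizer step follows: reduce-scatter gradients so each data-parallel rank owns one shard, run the Adam update locally on that fp32 shard, then all-gather the updated parameters back into the model buffer. The names and buffer layout are hypothetical assumptions (flat fp32 grad buffer and flat model-dtype param buffer, both padded to world_size * shard_size elements); this is a generic illustration of the technique, not the author's implementation or Megatron's API.

import torch
import torch.distributed as dist

def sharded_adam_step(flat_grads, flat_params, main_param_shard,
                      exp_avg, exp_avg_sq, group,
                      lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    # flat_grads: fp32 grad buffer; flat_params: model-dtype param buffer;
    # main_param_shard / exp_avg / exp_avg_sq: this rank's fp32 shards.
    world_size = dist.get_world_size(group=group)
    shard_size = main_param_shard.numel()

    # 1. Reduce-scatter gradients: each rank receives the summed gradient
    #    for its own shard only.
    grad_shard = torch.empty_like(main_param_shard)
    dist.reduce_scatter(
        grad_shard,
        list(flat_grads.view(world_size, shard_size).unbind(0)),
        group=group)

    # 2. Adam-style update on this rank's fp32 shard (simplified: no bias
    #    correction or weight decay).
    exp_avg.mul_(beta1).add_(grad_shard, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad_shard, grad_shard, value=1 - beta2)
    main_param_shard.addcdiv_(exp_avg, exp_avg_sq.sqrt().add_(eps), value=-lr)

    # 3. All-gather the updated shards back into the full parameter buffer
    #    (casting fp32 -> model dtype, e.g. fp16).
    dist.all_gather(
        list(flat_params.view(world_size, shard_size).unbind(0)),
        main_param_shard.to(flat_params.dtype),
        group=group)

The reduce_gradients() method added above covers the gradient-reduction stage of this pattern; the local update and the parameter all-gather are what the stubbed step() would still need to provide.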
@@ -410,60 +410,30 @@ def train_step(forward_step_func, data_iterator,
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # >>>
    # Forward pass.
    # <<<
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # >>>
    # Empty unused memory.
    # <<<
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # >>>
    # Reduce gradients. (With the distributed optimizer option, the optimizer
    # is now responsible for reducing gradients.)
    optimizer.reduce_gradients()
    # <<<

    # >>>
    from lutil import pax
    pax({"optimizer": optimizer})
    # <<<

    # Update parameters.
    timers('optimizer').start()
...