Commit 329fe582 authored by Lawrence McAfee

working on Float16DistributedOptimizer

parent 7dc8c475
@@ -168,6 +168,14 @@ def parse_args(extra_args_provider=None, defaults={},
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
# >>>
# The distributed optimizer requires local DDP with
# use-contiguous-buffers-in-local-ddp enabled.
if args.use_distributed_optimizer:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
args.use_contiguous_buffers_in_local_ddp = False
@@ -700,6 +708,10 @@ def _add_distributed_args(parser):
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
# >>>
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
# <<<
return parser
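The new flag is only valid together with local DDP and contiguous gradient buffers, mirroring the assertions added to parse_args above. A minimal standalone sketch of that constraint (hypothetical example, not part of the commit):

from types import SimpleNamespace

# Hypothetical args namespace; in Megatron these fields come from parse_args().
args = SimpleNamespace(use_distributed_optimizer=True,
                       DDP_impl='local',
                       use_contiguous_buffers_in_local_ddp=True)

# Same guard as in parse_args(): the distributed optimizer only works with
# local DDP and contiguous gradient buffers, otherwise the asserts fire.
if args.use_distributed_optimizer:
    assert args.DDP_impl == 'local'
    assert args.use_contiguous_buffers_in_local_ddp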
@@ -19,9 +19,17 @@ from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import LayerNorm
# >>>
from .distributed_fused_adam import DistributedFusedAdam
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
# >>>
from .optimizer import (
Float16OptimizerWithFloat16Params,
Float16DistributedOptimizer,
FP32Optimizer,
)
# <<<
def get_param_groups(modules,
no_weight_decay_cond,
@@ -97,7 +105,11 @@ def get_megatron_optimizer(model,
# })
# <<<
if args.optimizer == 'adam':
# >>>
if args.use_distributed_optimizer:
optimizer = DistributedFusedAdam(param_groups)
# <<<
elif args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
@@ -141,13 +153,18 @@ def get_megatron_optimizer(model,
hysteresis=args.hysteresis)
# Megatron optimizer.
return Float16OptimizerWithFloat16Params(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.bf16,
grad_scaler)
# >>>
opt_ty = Float16DistributedOptimizer \
if args.use_distributed_optimizer \
else Float16OptimizerWithFloat16Params
return opt_ty(optimizer,
args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp,
args.bf16,
grad_scaler)
# <<<
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
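For reference, here is a condensed view of the two selection points added to get_megatron_optimizer above. This is an illustrative sketch only; it assumes the module's existing imports (Adam, the optimizer classes) and elides the fp16/bf16 grad-scaler setup and the extra Adam hyperparameters:

def build_optimizer_sketch(args, param_groups, grad_scaler, params_have_main_grad):
    # Inner optimizer: DistributedFusedAdam replaces the default Adam when
    # --use-distributed-optimizer is set.
    if args.use_distributed_optimizer:
        optimizer = DistributedFusedAdam(param_groups)
    else:
        optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    # Outer mixed-precision wrapper: same constructor arguments either way,
    # only the class differs.
    opt_ty = Float16DistributedOptimizer if args.use_distributed_optimizer \
        else Float16OptimizerWithFloat16Params
    return opt_ty(optimizer,
                  args.clip_grad,
                  args.log_num_zeros_in_grad,
                  params_have_main_grad,
                  args.use_contiguous_buffers_in_local_ddp,
                  args.bf16,
                  grad_scaler)

The new DistributedFusedAdam module itself follows below.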
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
from megatron import mpu
# >>>
from lutil import pax, tp
# <<<
class DistributedFusedAdam(torch.optim.Optimizer):
def __init__(self, params):
super().__init__(params, defaults = {})
self.initialized = False
# self.params_32 = None
# self.grads_32 = None
# self.opt_m = None
# self.opt_v = None
# pax(0, {
# "param_groups" : self.param_groups,
# "param_groups / 0" : self.param_groups[0],
# "param_groups / 1" : self.param_groups[1],
# "param_groups / 0 / params" : self.param_groups[0]["params"],
# # "param_groups / params" : [ g["params"] for g in self.param_groups ],
# })
def initialize(self):
# Lazy one-time setup: allocate this rank's shards on the first step() call;
# subsequent calls are no-ops.
if self.initialized:
return
self.initialized = True
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
total_param_size = sum(
p.numel()
for g in self.param_groups
for p in g["params"]
)
shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
shard_start_index = data_parallel_rank * shard_size
shard_end_index = min(total_param_size, shard_start_index + shard_size)
shard_size = shard_end_index - shard_start_index
allocate_shard = lambda dtype : torch.empty(
[shard_size],
dtype = dtype,
device = torch.cuda.current_device())
self.main_param_shard = allocate_shard(torch.float)
self.main_grad_shard = allocate_shard(torch.float)
self.adam_m_shard = allocate_shard(torch.float)
self.adam_v_shard = allocate_shard(torch.float)
# pax(2, {
# "data_parallel_rank" : data_parallel_rank,
# "data_parallel_world_size" : data_parallel_world_size,
# "total_param_size" : total_param_size,
# "shard_size" : shard_size,
# "shard" : "%d [ %d, %d ]" % (
# shard_size,
# shard_start_index,
# shard_end_index,
# ),
# })
def step(self):
self.initialize()
raise Exception("what's next?")
# >>>
# eof
# <<<
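The shard bookkeeping in initialize() splits the flat parameter space evenly across data-parallel ranks, with the last rank holding whatever remainder is left, and then allocates four fp32 buffers of that local size (main params, main grads, and the Adam m/v state). A standalone worked example of the index arithmetic:

import math

# Example: 10 parameter elements sharded across 4 data-parallel ranks.
total_param_size = 10
data_parallel_world_size = 4

for data_parallel_rank in range(data_parallel_world_size):
    shard_size = int(math.ceil(total_param_size / data_parallel_world_size))  # 3
    shard_start_index = data_parallel_rank * shard_size           # 0, 3, 6, 9
    shard_end_index = min(total_param_size,
                          shard_start_index + shard_size)         # 3, 6, 9, 10
    shard_size = shard_end_index - shard_start_index              # 3, 3, 3, 1
    print(data_parallel_rank, shard_start_index, shard_end_index, shard_size)
# -> ranks 0-2 each own 3 elements; rank 3 owns the final 1 element.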
@@ -275,6 +275,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# <<<
# >>>
# debug()
# from lutil import pax, tp
# pax(0, {
# "param" : tp(param),
# "main_param" : tp(main_param),
# })
# <<<
fp32_from_float16_params_this_group.append(main_param)
# Reset existing state dict key to the new main param.
@@ -354,6 +360,84 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
return self.grad_scaler.scale
# >>>
def reduce_gradients(self, model):
# Moved here from train_step(); assumes get_args, get_timers, unwrap_model,
# torchDDP, LocalDDP and Float16Module are importable in this module.
args = get_args()
timers = get_timers()
# >>>
# if not args.use_distributed_optimizer:
# All-reduce if needed.
# >>>
# if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
if args.DDP_impl == 'local':
# <<<
timers('backward-params-all-reduce').start()
for model_module in model:
# >>>
# from lutil import pax, tp
# pax(0, {
# "model" : model,
# "model_module" : model_module,
# })
# <<<
# >>>
# e.g., grad_shard = optimizer.get_grad_shard()
# <<<
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
# if args.DDP_impl == 'local':
# grad = word_embeddings_weight.main_grad
# else:
# grad = word_embeddings_weight.grad
# torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
grad_shard = self.get_grad_shard(word_embeddings_weight)
torch.distributed.all_reduce(grad_shard,
group=mpu.get_embedding_group())
# <<<
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
# >>>
# grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
# torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# +++
grad_shard = self.get_grad_shard(
unwrapped_model.language_model.embedding.position_embeddings.weight)
torch.distributed.all_reduce(grad_shard,
group=mpu.get_position_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
def _copy_model_grads_to_main_grads(self):
# This only needs to be done for the float16 group.
for model_group, main_group in zip(self.float16_groups,
@@ -542,6 +626,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
current_param.data.copy_(saved_param.data)
# >>>
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
def step(self):
raise Exception("hi.")
# <<<
class FP32Optimizer(MegatronOptimizer):
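Note that reduce_gradients above calls self.get_grad_shard(param), which this commit does not define anywhere yet (the "e.g., grad_shard = optimizer.get_grad_shard()" comment marks it as a to-do). A hypothetical sketch of such a helper on DistributedFusedAdam (the fp16 wrapper could forward to it), assuming initialize() also records each parameter's overlap with this rank's flat [shard_start_index, shard_end_index) range:

# Hypothetical helper, not part of this commit: return the slice of this
# rank's flat fp32 gradient shard that corresponds to a given parameter.
def get_grad_shard(self, param):
    # self.param_shard_ranges is an assumed mapping built in initialize():
    # param -> (offset_in_shard, numel_in_shard), covering only the portion
    # of the parameter that falls inside this rank's shard.
    offset, numel = self.param_shard_ranges[param]
    return self.main_grad_shard[offset:offset + numel]

The still-unimplemented step() methods would then presumably run Adam on the local shard and gather the updated parameters back to the model, but that part is left open in this commit ("what's next?").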
@@ -410,60 +410,30 @@ def train_step(forward_step_func, data_iterator,
partition.zero_grad_buffer()
optimizer.zero_grad()
# >>>
# Forward pass.
# <<<
forward_backward_func = get_forward_backward_func()
losses_reduced = forward_backward_func(
forward_step_func, data_iterator, model,
optimizer, timers, forward_only=False)
# Empty unused memory
# >>>
# Empty unused memory.
# <<<
if args.empty_unused_memory_level >= 1:
torch.cuda.empty_cache()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('backward-params-all-reduce').start()
for model_module in model:
model_module.allreduce_gradients()
timers('backward-params-all-reduce').stop()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]
elif mpu.is_pipeline_last_stage(ignore_virtual=True):
unwrapped_model = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if mpu.is_rank_in_position_embedding_group() and \
mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.pipeline_model_parallel_split_rank is not None:
unwrapped_model = model[0]
unwrapped_model = unwrap_model(
unwrapped_model, (torchDDP, LocalDDP, Float16Module))
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
timers('backward-embedding-all-reduce').stop()
# >>>
# Reduce gradients. (With the distributed-optimizer option, the optimizer
# is now responsible for reducing gradients.)
optimizer.reduce_gradients(model)
# <<<
# >>>
from lutil import pax
pax({"optimizer": optimizer})
# <<<
# Update parameters.
timers('optimizer').start()
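With the all-reduce block deleted from train_step, gradient reduction collapses to a single optimizer call. A sketch of the resulting sequence (illustrative only; it assumes the reduce_gradients(model) signature used above):

def train_step_tail_sketch(optimizer, model, timers):
    # Gradient reduction is delegated to the optimizer: today it performs the
    # same local-DDP all-reduces that train_step used to do inline; with the
    # distributed optimizer it can instead reduce per-rank gradient shards.
    optimizer.reduce_gradients(model)

    # Parameter update then proceeds as before.
    timers('optimizer').start()
    optimizer.step()  # return values elided; see train_step above for details
    timers('optimizer').stop()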