Commit 49cca4d9 authored by Lawrence McAfee

more work on Float16DistributedOptimizer

parent 329fe582
@@ -20,7 +20,7 @@ from megatron import get_args
 from megatron.model import LayerNorm
 # >>>
-from .distributed_fused_adam import DistributedFusedAdam
+# from .distributed_fused_adam import DistributedFusedAdam
 # <<<
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 # >>>
@@ -106,10 +106,11 @@ def get_megatron_optimizer(model,
     # <<<
     # >>>
-    if args.use_distributed_optimizer:
-        optimizer = DistributedFusedAdam(param_groups)
+    # if args.use_distributed_optimizer:
+    #     optimizer = DistributedFusedAdam(param_groups)
+    # elif args.optimizer == 'adam':
     # <<<
-    elif args.optimizer == 'adam':
+    if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
@@ -167,7 +168,12 @@ def get_megatron_optimizer(model,
     # <<<

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad,
-                         args.log_num_zeros_in_grad,
-                         params_have_main_grad,
-                         args.use_contiguous_buffers_in_local_ddp)
+    # >>>
+    opt_ty = Float32DistributedOptimizer \
+        if args.use_distributed_optimizer \
+        else Float32Optimizer
+    return opt_ty(optimizer, args.clip_grad,
+                  args.log_num_zeros_in_grad,
+                  params_have_main_grad,
+                  args.use_contiguous_buffers_in_local_ddp)
+    # <<<
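For orientation: the FP32 branch above now picks its wrapper class with a conditional expression instead of hard-coding it. The FP16/BF16 branch is not touched in this hunk, but since the Float16DistributedOptimizer added further down forwards *args straight to Float16OptimizerWithFloat16Params, the analogous wiring would presumably be the mirror image. A sketch of that assumed follow-up (argument list taken from the existing fp16 constructor call; none of this is in the commit itself):

# Hypothetical follow-up for the fp16/bf16 branch, mirroring the FP32 change.
opt_ty = Float16DistributedOptimizer \
    if args.use_distributed_optimizer \
    else Float16OptimizerWithFloat16Params
return opt_ty(optimizer,
              args.clip_grad,
              args.log_num_zeros_in_grad,
              params_have_main_grad,
              args.use_contiguous_buffers_in_local_ddp,
              args.bf16,
              grad_scaler)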
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import math
 import torch
@@ -29,6 +29,9 @@ from megatron import print_rank_0
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

+# >>>
+from lutil import pax, tp
+# <<<

 def _zero_grad_group_helper(group, set_to_none):
     """Zero out the gradient for a group of parameters.
@@ -361,7 +364,20 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     # >>>
-    def reduce_gradientss(self):
+    def reduce_gradients(self, model):
+
+        # >>>
+        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        from megatron import get_timers
+        from megatron.model import DistributedDataParallel as LocalDDP
+        from megatron.model import Float16Module
+        from megatron.utils import unwrap_model
+        args = get_args()
+        timers = get_timers()
+        # <<<

         # >>>
         # if not args.use_distributed_optimizer:
@@ -405,15 +421,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             if unwrapped_model.share_word_embeddings:
                 word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                 # >>>
-                # if args.DDP_impl == 'local':
-                #     grad = word_embeddings_weight.main_grad
-                # else:
-                #     grad = word_embeddings_weight.grad
-                # torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+                if args.DDP_impl == 'local':
+                    grad = word_embeddings_weight.main_grad
+                else:
+                    grad = word_embeddings_weight.grad
+                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                 # +++
-                grad_shard = optimizer.get_grad_shard(word_embeddings)
-                torch.distributed.all_reduce(grad_shard,
-                                             group=mpu.get_embedding_group())
+                # grad_shard = optimizer.get_grad_shard(word_embeddings)
+                # torch.distributed.all_reduce(grad_shard,
+                #                              group=mpu.get_embedding_group())
                 # <<<

         # All-reduce position_embeddings grad across first (encoder) and split (decoder)
@@ -428,13 +444,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             assert args.DDP_impl == 'local', \
                 'T5 model is only supported with local DDP mode'
             # >>>
-            # grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-            # torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
             # +++
-            grad_shard = optimizer.get_grad_shard(
-                unwrapped_model.language_model.embedding.position_embeddings.weight)
-            torch.distributed.all_reduce(grad_shard,
-                                         group=mpu.get_position_embedding_group())
+            # grad_shard = optimizer.get_grad_shard(
+            #     unwrapped_model.language_model.embedding.position_embeddings.weight)
+            # torch.distributed.all_reduce(grad_shard,
+            #                              group=mpu.get_position_embedding_group())
             # <<<

         timers('backward-embedding-all-reduce').stop()
@@ -629,9 +645,111 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):

 # >>>
 class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.initialized = False
+        # >>>
+        self.initialize()
+        # <<<
+
+    def initialize(self):
+
+        # >>>
+        import math
+        # <<<
+
+        if self.initialized:
+            raise Exception("initialization worked.")
+            return
+        self.initialized = True
+
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+
+        total_param_size = sum(
+            p.numel()
+            for g in self.param_groups
+            for p in g["params"]
+        )
+        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
+        shard_start_index = data_parallel_rank * shard_size
+        shard_end_index = min(total_param_size, shard_start_index + shard_size)
+        self.shard_size = shard_end_index - shard_start_index
+
+        # allocate_shard = lambda dtype : torch.empty(
+        #     [self.shard_size],
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device())
+        allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
+
+        self.main_param_shard = allocate_shard(torch.float)
+        self.main_grad_shard = allocate_shard(torch.float)
+        self.adam_m_shard = allocate_shard(torch.float)
+        self.adam_v_shard = allocate_shard(torch.float)
+
+    def reduce_gradients(self, model):
+
+        # >>>
+        # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        # from megatron import get_timers
+        # from megatron.model import DistributedDataParallel as LocalDDP
+        # from megatron.model import Float16Module
+        # from megatron.utils import unwrap_model
+        args = get_args()
+        # timers = get_timers()
+        # <<<

+        # >>>
+        assert args.use_contiguous_buffers_in_local_ddp
+        # <<<
+
+        # grad_buffers = [ m._grad_buffers for m in model ]
+        for virtual_model in model:
+            grad_buffers = virtual_model._grad_buffers
+            for dtype, grad_buffer in grad_buffers.items():
+
+                dp_grad_buffers = [
+                    grad_buffer.get(self.shard_sizes[i],
+                                    self.shard_start_indexes[i])
+                    for i in self.data_parallel_world_size]
+
+                pax(0, {"dp_grad_buffers": dp_grad_buffers})
+
+                torch.distributed.reduce_scatter(
+                    self.main_grad_shard,
+                    grad_buffer.data,
+                    group = mpu.get_data_parallel_group(),
+                )
+
+                # >>>
+                pax(0, {
+                    "virtual_model" : virtual_model,
+                    "grad_buffers" : grad_buffers,
+                    "dtype" : dtype,
+                    "grad_buffer / len" : grad_buffer.numel,
+                    "grad_buffer / data" : tp(grad_buffer.data),
+                    # "optimizer" : self.optimizer,
+                    "main_grad_shard" : tp(self.main_grad_shard),
+                })
+                # <<<
+
+        # >>>
+        from lutil import pax, tp
+        pax(0, {
+            "model" : model,
+            "grad_buffers" : grad_buffers,
+            "grad_buffers / 0" : grad_buffers[0],
+            "grad_buffers / 0 / data" : tp(list(grad_buffers[0].values())[0].data),
+        })
+        # <<<
+
     def step(self):
-        raise Exception("hi.")
+        raise Exception("step.")
 # <<<
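The heart of the new class is the shard partitioning in initialize() and the reduce-scatter in reduce_gradients() (which still references self.shard_sizes / self.shard_start_indexes that initialize() does not set yet). A self-contained sketch of the same idea using plain torch tensors instead of Megatron's MemoryBuffer; the helper names compute_shard_range and reduce_scatter_into_shard are illustrative, not from the commit:

# Sketch only: the shard arithmetic used by initialize(), and a reduce-scatter
# that leaves each data-parallel rank with the summed gradients of its shard.
import math
import torch
import torch.distributed as dist

def compute_shard_range(total_size, rank, world_size):
    # Equal ceil-sized shards; the last rank's shard may be shorter.
    shard_size = int(math.ceil(total_size / world_size))
    start = rank * shard_size
    end = min(total_size, start + shard_size)
    return start, end, shard_size

# Example: total_size=10, world_size=4 -> shard_size=3, so the ranks own
# [0:3], [3:6], [6:9], [9:10] of the flat parameter/gradient buffer.

def reduce_scatter_into_shard(flat_grad_buffer, grad_shard, group):
    # flat_grad_buffer is assumed padded to world_size * shard_size elements,
    # so it splits into equal chunks, one per data-parallel rank.
    world_size = dist.get_world_size(group=group)
    chunks = list(flat_grad_buffer.view(world_size, grad_shard.numel()).unbind(0))
    # After this call, grad_shard holds the sum over all ranks of this rank's chunk.
    dist.reduce_scatter(grad_shard, chunks, group=group)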
@@ -427,12 +427,12 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Reduce gradients. (with distributed optimizer option, optimizer
     # now responsible for reducing gradients)
-    optimizer.reduce_gradients()
+    optimizer.reduce_gradients(model)
     # <<<

     # >>>
-    from lutil import pax
-    pax({"optimizer": optimizer})
+    # from lutil import pax
+    # pax(0, {"optimizer": optimizer})
     # <<<

     # Update parameters.
@@ -440,6 +440,12 @@ def train_step(forward_step_func, data_iterator,
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
     timers('optimizer').stop()

+    # >>>
+    # Gather params gradients. (with distributed optimizer option, optimizer
+    # now responsible for gathering updated params)
+    optimizer.gather_params()
+    # <<<
+
     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \
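Taken together, the train_step hunks lay out the intended division of labor: the backward pass accumulates into the local-DDP contiguous grad buffers, reduce_gradients(model) reduce-scatters them so each data-parallel rank owns the reduced gradients for its shard, step() updates only that fp32 shard (with Adam state sharded the same way), and gather_params(), which is called here but not implemented in this commit, would all-gather the updated parameters back to every rank. A sketch of that per-iteration flow under those assumptions; the wrapper function name is illustrative:

# Illustrative outline of the intended distributed-optimizer iteration.
def distributed_optimizer_iteration(model, optimizer, timers):
    # 1. Reduce-scatter the contiguous grad buffers: each data-parallel rank
    #    ends up with the summed gradients for its own parameter shard.
    optimizer.reduce_gradients(model)

    # 2. Update only the locally owned fp32 main-parameter shard (Adam m/v
    #    state is sharded the same way, cutting optimizer memory by the
    #    data-parallel world size).
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # 3. All-gather the updated parameter shards so every rank again holds a
    #    full copy of the fp16/bf16 parameters for the next forward pass.
    #    (gather_params() is referenced in the diff but not implemented yet.)
    optimizer.gather_params()

    return update_successful, grad_norm, num_zeros_in_grad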