"docs/vscode:/vscode.git/clone" did not exist on "0bc6be69606f9ae7f82ad499baf494255c51c38d"
Commit 49cca4d9 authored by Lawrence McAfee's avatar Lawrence McAfee
Browse files

more work on Float16DistributedOptimizer

parent 329fe582
@@ -20,7 +20,7 @@ from megatron import get_args
from megatron.model import LayerNorm
# >>>
from .distributed_fused_adam import DistributedFusedAdam
# from .distributed_fused_adam import DistributedFusedAdam
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
# >>>
@@ -106,10 +106,11 @@ def get_megatron_optimizer(model,
# <<<
# >>>
if args.use_distributed_optimizer:
optimizer = DistributedFusedAdam(param_groups)
# if args.use_distributed_optimizer:
# optimizer = DistributedFusedAdam(param_groups)
# elif args.optimizer == 'adam':
# <<<
elif args.optimizer == 'adam':
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
@@ -167,7 +168,12 @@ def get_megatron_optimizer(model,
# <<<
# FP32.
return FP32Optimizer(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp)
# >>>
opt_ty = Float32DistributedOptimizer \
if args.use_distributed_optimizer \
else Float32Optimizer
return opt_ty(optimizer, args.clip_grad,
args.log_num_zeros_in_grad,
params_have_main_grad,
args.use_contiguous_buffers_in_local_ddp)
# <<<
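
The hunk above replaces the unconditional FP32Optimizer return with a wrapper class chosen by args.use_distributed_optimizer, so both branches share one constructor call. A minimal, self-contained sketch of that dispatch idiom follows; _Plain and _Distributed are hypothetical stand-ins, not the real Megatron wrapper classes.

# Sketch only: hypothetical stand-in classes illustrating the dispatch idiom.
class _Plain:
    def __init__(self, inner_optimizer, clip_grad):
        self.inner_optimizer = inner_optimizer
        self.clip_grad = clip_grad

class _Distributed(_Plain):
    pass  # would add shard bookkeeping on top of the plain wrapper

def build_wrapper(inner_optimizer, clip_grad, use_distributed_optimizer):
    # Pick the class first, then construct with one shared argument list;
    # this is what lets the diff keep a single return statement.
    wrapper_cls = _Distributed if use_distributed_optimizer else _Plain
    return wrapper_cls(inner_optimizer, clip_grad)
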
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
@@ -29,6 +29,9 @@ from megatron import print_rank_0
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from lutil import pax, tp
# <<<
def _zero_grad_group_helper(group, set_to_none):
"""Zero out the gradient for a group of parameters.
@@ -361,7 +364,20 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# >>>
def reduce_gradientss(self):
def reduce_gradients(self, model):
# >>>
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.utils import unwrap_model
args = get_args()
timers = get_timers()
# <<<
# >>>
# if not args.use_distributed_optimizer:
@@ -405,15 +421,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if unwrapped_model.share_word_embeddings:
word_embeddings_weight = unwrapped_model.word_embeddings_weight()
# >>>
# if args.DDP_impl == 'local':
# grad = word_embeddings_weight.main_grad
# else:
# grad = word_embeddings_weight.grad
# torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
if args.DDP_impl == 'local':
grad = word_embeddings_weight.main_grad
else:
grad = word_embeddings_weight.grad
torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
# +++
grad_shard = optimizer.get_grad_shard(word_embeddings)
torch.distributed.all_reduce(grad_shard,
group=mpu.get_embedding_group())
# grad_shard = optimizer.get_grad_shard(word_embeddings)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_embedding_group())
# <<<
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
@@ -428,13 +444,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
assert args.DDP_impl == 'local', \
'T5 model is only supported with local DDP mode'
# >>>
# grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
# torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
# +++
grad_shard = optimizer.get_grad_shard(
unwrapped_model.language_model.embedding.position_embeddings.weight)
torch.distributed.all_reduce(grad_shard,
group=mpu.get_position_embedding_group())
# grad_shard = optimizer.get_grad_shard(
# unwrapped_model.language_model.embedding.position_embeddings.weight)
# torch.distributed.all_reduce(grad_shard,
# group=mpu.get_position_embedding_group())
# <<<
timers('backward-embedding-all-reduce').stop()
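
In the hunks above the commit reverts to the plain embedding-gradient synchronization: the pipeline stages that share the word (and position) embedding weights all-reduce their gradients over mpu.get_embedding_group(), and the shard-based variant is left commented out. A standalone sketch of that sync pattern follows, with a generic `group` standing in for the embedding process group.

import torch
import torch.distributed as dist

def sync_shared_embedding_grad(weight: torch.Tensor, group=None):
    # With local DDP the accumulated gradient lives in .main_grad; otherwise
    # fall back to the usual .grad attribute, as in the code above.
    grad = getattr(weight, "main_grad", None)
    if grad is None:
        grad = weight.grad
    if grad is not None:
        # Sum the copies held by the ranks in `group` so they stay identical.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
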
@@ -629,9 +645,111 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# >>>
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
def __init__(self, *args):
super().__init__(*args)
self.initialized = False
# >>>
self.initialize()
# <<<
def initialize(self):
# >>>
import math
# <<<
if self.initialized:
raise Exception("initialization worked.")
return
self.initialized = True
data_parallel_rank = mpu.get_data_parallel_rank()
data_parallel_world_size = mpu.get_data_parallel_world_size()
total_param_size = sum(
p.numel()
for g in self.param_groups
for p in g["params"]
)
shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
shard_start_index = data_parallel_rank * shard_size
shard_end_index = min(total_param_size, shard_start_index + shard_size)
self.shard_size = shard_end_index - shard_start_index
# allocate_shard = lambda dtype : torch.empty(
# [self.shard_size],
# dtype = dtype,
# device = torch.cuda.current_device())
allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
self.main_param_shard = allocate_shard(torch.float)
self.main_grad_shard = allocate_shard(torch.float)
self.adam_m_shard = allocate_shard(torch.float)
self.adam_v_shard = allocate_shard(torch.float)
def reduce_gradients(self, model):
# >>>
# from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
# from megatron import get_timers
# from megatron.model import DistributedDataParallel as LocalDDP
# from megatron.model import Float16Module
# from megatron.utils import unwrap_model
args = get_args()
# timers = get_timers()
# <<<
# >>>
assert args.use_contiguous_buffers_in_local_ddp
# <<<
# grad_buffers = [ m._grad_buffers for m in model ]
for virtual_model in model:
grad_buffers = virtual_model._grad_buffers
for dtype, grad_buffer in grad_buffers.items():
dp_grad_buffers = [
grad_buffer.get(self.shard_sizes[i],
self.shard_start_indexes[i])
for i in self.data_parallel_world_size]
pax(0, {"dp_grad_buffers": dp_grad_buffers})
torch.distributed.reduce_scatter(
self.main_grad_shard,
grad_buffer.data,
group = mpu.get_data_parallel_group(),
)
# >>>
pax(0, {
"virtual_model" : virtual_model,
"grad_buffers" : grad_buffers,
"dtype" : dtype,
"grad_buffer / len" : grad_buffer.numel,
"grad_buffer / data" : tp(grad_buffer.data),
# "optimizer" : self.optimizer,
"main_grad_shard" : tp(self.main_grad_shard),
})
# <<<
# >>>
from lutil import pax, tp
pax(0, {
"model" : model,
"grad_buffers" : grad_buffers,
"grad_buffers / 0" : grad_buffers[0],
"grad_buffers / 0 / data" :tp(list(grad_buffers[0].values())[0].data),
})
# <<<
def step(self):
raise Exception("hi.")
raise Exception("step.")
# <<<
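
initialize() above treats all parameter groups as one flat buffer, gives each data-parallel rank a ceil(total / world_size)-sized slice of it (the last rank's slice may be shorter), and allocates four fp32 shard buffers of that size: main params, main grads, and the Adam m and v shards. A small self-contained sketch of the index arithmetic, with a worked example; the function name is illustrative, not part of the class.

import math

def shard_range(total_param_size: int, rank: int, world_size: int):
    # Nominal shard of ceil(total / world_size) elements; the last rank's
    # shard is clipped to the true total, as in initialize() above.
    shard_size = int(math.ceil(total_param_size / world_size))
    start = rank * shard_size
    end = min(total_param_size, start + shard_size)
    return start, end, end - start

# Worked example: 10 parameters over 4 data-parallel ranks ->
# nominal shard size 3; ranks 0-2 own 3 elements each, rank 3 owns 1.
assert shard_range(10, 0, 4) == (0, 3, 3)
assert shard_range(10, 3, 4) == (9, 10, 1)
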
@@ -427,12 +427,12 @@ def train_step(forward_step_func, data_iterator,
# >>>
# Reduce gradients. (with distributed optimizer option, optimizer
# now responsible for reducing gradients)
optimizer.reduce_gradients()
optimizer.reduce_gradients(model)
# <<<
# >>>
from lutil import pax
pax({"optimizer": optimizer})
# from lutil import pax
# pax(0, {"optimizer": optimizer})
# <<<
# Update parameters.
@@ -440,6 +440,12 @@ def train_step(forward_step_func, data_iterator,
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
# >>>
# Gather params gradients. (with distributed optimizer option, optimizer
# now responsible for gathering updated params)
optimizer.gather_params()
# <<<
# Update learning rate.
if update_successful:
increment = get_num_microbatches() * \
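
Taken together, the train_step changes split the update into three collective phases when the distributed optimizer is enabled: reduce_gradients(model) reduce-scatters the contiguous grad buffers so each data-parallel rank holds only the summed gradient of its shard, optimizer.step() updates that shard locally, and gather_params() all-gathers the updated parameters back onto every rank. Below is a self-contained sketch of that flow on a flat 1-D buffer using plain torch.distributed collectives; the names are illustrative, not the Megatron API, and it assumes for simplicity that the buffer length divides evenly by the data-parallel world size.

import torch
import torch.distributed as dist

def distributed_optimizer_step(flat_params: torch.Tensor,
                               flat_grads: torch.Tensor,
                               inner_step,
                               group=None):
    """Sketch: reduce-scatter grads, update the owned shard, all-gather params.

    Assumes flat_params / flat_grads are contiguous 1-D buffers whose length
    is divisible by the data-parallel world size, and that
    inner_step(param_shard, grad_shard) applies any in-place update.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard_size = flat_grads.numel() // world_size

    # 1) reduce_gradients(): every rank receives the summed gradient for
    #    the shard it owns, and nothing else.
    grad_shard = torch.empty(shard_size, dtype=flat_grads.dtype,
                             device=flat_grads.device)
    dist.reduce_scatter(grad_shard, list(flat_grads.chunk(world_size)),
                        group=group)

    # 2) optimizer.step(): apply the update to the locally owned slice only.
    param_shard = flat_params.narrow(0, rank * shard_size, shard_size)
    inner_step(param_shard, grad_shard)

    # 3) gather_params(): replicate every rank's updated shard back into the
    #    full parameter buffer on all ranks.
    dist.all_gather(list(flat_params.chunk(world_size)), param_shard,
                    group=group)
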