"git@developer.sourcefind.cn:OpenDAS/torch-spline-conv.git" did not exist on "67904212918f0d9114a3e87f1ac0dbcd958f2ab6"
Commit 11581195 authored by Lawrence McAfee

removed mem savings factor. +cleanup.

parent 41038d54
@@ -733,12 +733,6 @@ def _add_distributed_args(parser):
                        'affects the encoder embedding.)')
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
-    group.add_argument('--distrib-opt-comm-mem-savings', default=0., type=float,
-                       help='Trade-off memory savings & iteration time, for '
-                       'disributed optimizer\'s communication operations (i.e., '
-                       '(reduce/gather). This value ranges from 0.0 (default, '
-                       'no memory savings) to 1.0 (max memory savings, at the '
-                       'expense of iteration time).')
     return parser
...
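For context, the option removed here traded peak memory for extra communication launches by chunking the reduce/gather collectives. The sketch below reconstructs how that factor mapped to a chunk size, following the removed get_model_grad_buffer_dp_views_chunked() helper that appears later in this diff; chunk_numel_for_factor() is a hypothetical standalone function written for illustration, not code that remains after this commit.

```python
def chunk_numel_for_factor(view_numel: int, mem_savings_factor: float) -> int:
    """Interpolate between a small fixed chunk (max memory savings, more
    communication calls) and the full per-rank view (no savings, one call),
    mirroring the removed helper's formula."""
    chunk_numel_min = 131072           # constant from the removed code
    chunk_numel_max = view_numel       # one chunk spanning the whole view
    return int(mem_savings_factor * chunk_numel_min
               + (1 - mem_savings_factor) * chunk_numel_max)

# Example with a 1M-element per-rank view:
#   factor 0.0 -> 1000000 elements per chunk (single large transfer)
#   factor 1.0 -> 131072 elements per chunk (many small transfers)
print(chunk_numel_for_factor(1_000_000, 0.0))
print(chunk_numel_for_factor(1_000_000, 1.0))
```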
@@ -51,31 +51,8 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
-    # >>>
-    # # Filter parameters based on:
-    # #   - grad should not be none
-    # #   - parameter should not be shared
-    # #   - should not be a replica due to tensor model parallelism
-    # grads = []
-    # grads_for_norm = []
-    # for param in parameters:
-    #     grad_not_none = param.grad is not None
-    #     is_not_shared = param_is_not_shared(param)
-    #     is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-    #     if grad_not_none:
-    #         grad = param.grad.detach()
-    #     if grad_not_none:
-    #         # Make sure the grads are in fp32
-    #         assert param.grad.type() == 'torch.cuda.FloatTensor'
-    #         grads.append(grad)
-    #     if grad_not_none and is_not_shared and is_not_tp_duplicate:
-    #         grads_for_norm.append(grad)
-    # <<<
-    # >>>
     # Grads.
-    grads = [ p.grad for p in parameters if p is not None ]
-    # <<<
+    grads = [ p.grad.detach() for p in parameters if p.grad is not None ]

     # Norm parameters.
     max_norm = float(max_norm)

@@ -119,17 +96,6 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
                                      group=model_parallel_group)
         total_norm = total_norm.item() ** (1.0 / norm_type)
-    # >>>
-    # from megatron import get_args
-    # from lutil import pax
-    # args = get_args()
-    # pax(0, {
-    #     "use distrib opt" : args.use_distributed_optimizer,
-    #     "norm_type" : norm_type,
-    #     "total_norm" : total_norm,
-    # })
-    # <<<

     # Scale.
     clip_coeff = max_norm / (total_norm + 1.0e-6)
     if clip_coeff < 1.0:
...
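The retained list comprehension above also fixes the filter condition (`p.grad is not None` rather than `p is not None`) and detaches each grad. For readers of this hunk, here is a minimal sketch of the surrounding clipping math, assuming a finite norm_type; clip_grads_sketch is an illustrative stand-in, while the real clip_grad_norm_fp32 typically uses fused multi-tensor kernels and handles the inf-norm case separately.

```python
import torch

def clip_grads_sketch(grads, grads_for_norm, max_norm, norm_type=2.0,
                      model_parallel_group=None):
    """Sketch of the clipping logic around this hunk (finite norms only)."""
    # Accumulate sum(|g|^p) over the de-duplicated grads (no shared params,
    # no tensor-parallel duplicates).
    device = grads_for_norm[0].device if grads_for_norm else torch.device("cpu")
    total_norm = torch.zeros(1, dtype=torch.float, device=device)
    for g in grads_for_norm:
        total_norm += g.float().norm(norm_type) ** norm_type
    # Sum the partial norms across the model-parallel group, as above.
    if model_parallel_group is not None:
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
    total_norm = total_norm.item() ** (1.0 / norm_type)
    # Scale: same formula as the retained code.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm
```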
@@ -22,17 +22,11 @@ import torch
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
-# >>>
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-# <<<
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# >>>
-from .optimizer import get_clippy
-from lutil import pax, tp
-# <<<

 class Shard:
     def __init__(self, start, end):
@@ -196,12 +190,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             # Update group's param.
             group_shard["orig_group"]["params"] = [ main_param ]

-    # >>>
     @classmethod
     def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
         grad_views = []
-        # grad_views_SKIPPED = []
         for group_index, opt_group_shard in enumerate(opt_group_shards):
             opt_grad = optimizer.param_groups[group_index]["params"][0].grad
             for param, shard in opt_group_shard["param_map"].items():
@@ -211,30 +203,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     grad_view = opt_grad[shard.start:shard.end]
                     grad_views.append(grad_view)
-                # else:
-                #     grad_views_SKIPPED.append(opt_grad[shard.start:shard.end])
-        # >>>
-        # my_rank = torch.distributed.get_rank()
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == my_rank:
-        #         print("r %d, grad views %s." % (
-        #             my_rank,
-        #             ", ".join(str(tuple(g.shape)) for g in grad_views),
-        #         ))
-        #     torch.distributed.barrier()
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == my_rank:
-        #         print("r %d, SKIPPED %s." % (
-        #             my_rank,
-        #             ", ".join(str(tuple(g.shape)) for g in grad_views_SKIPPED),
-        #         ))
-        #     torch.distributed.barrier()
-        # exit(0)
-        # <<<
         return grad_views
-    # <<<

     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
@@ -274,16 +243,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Initialize main params.
         self._copy_model_params_to_main_params()
-        # >>> numel/nelem per rank >>>
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == torch.distributed.get_rank():
-        #         for m in self.models:
-        #             for b in m._grad_buffers.values():
-        #                 print("r %d, %d." % (r, b.data.nelement()))
-        #     torch.distributed.barrier()
-        # exit(0)
-        # <<<
         # Params for grad norm.
         self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
             self.opt_group_shards,
@@ -293,47 +252,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_parallel_group(self):
         return None

-    # >>>
-    # @staticmethod
-    # def has_nan_debug(tensors):
-    #     if isinstance(tensors, torch.Tensor):
-    #         tensors = [ tensors ]
-    #     assert isinstance(tensors, list)
-    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
-    #     has_nan = any(has_nans)
-    #     return has_nan
-    # def get_local_model_param_views(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     model_param_views = []
-    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-    #         for param, opt_shard in opt_group_shard["param_map"].items():
-    #             model_index, dtype = self.param_gbuf_map[param]
-    #             gbuf_shard_map = \
-    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-    #             model_param_shard = gbuf_shard_map["param"]
-    #             model_param_views.append(
-    #                 param.view(-1)[model_param_shard.start:model_param_shard.end])
-    #     return model_param_views
-    # def get_local_model_grad_views(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     model_grad_views = []
-    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-    #         for param, opt_shard in opt_group_shard["param_map"].items():
-    #             model_index, dtype = self.param_gbuf_map[param]
-    #             gbuf = self.models[model_index]._grad_buffers[dtype].data
-    #             gbuf_shard_map = \
-    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-    #             gbuf_world_shard = gbuf_shard_map["gbuf_world"]
-    #             model_grad_views.append(
-    #                 gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
-    #     return model_grad_views
-    # def get_world_model_params(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     return [ p for m in self.models for p in m.parameters() ]
-    # def get_world_model_grads(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     return [ p.main_grad for p in self.get_world_model_params() ]
-    # <<<

     def get_main_params(self):
         return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -344,10 +262,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_main_grad(self, group_index):
         return self.get_main_param(group_index).grad

-    # >>>
-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         return self.main_grad_views_for_grad_norm
-    # <<<

     def state_dict(self):
         state_dict = {}
@@ -386,6 +304,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)

     def zero_grad(self, set_to_none=True):

         # Collect model params.

@@ -397,6 +316,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Distributed optimizer requires contiguous buffer; don't set to None.
         _zero_grad_group_helper(model_params, set_to_none = False)

     def get_model_grad_buffer_dp_views(self):
         data_parallel_world_size = mpu.get_data_parallel_world_size()
@@ -410,53 +330,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                 gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                               for r in range(data_parallel_world_size)]
-                # gbuf_view_items.append((model_index, dtype, gbuf_views))
                 gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

         return gbuf_view_items

-    # >>>
-    # def get_model_grad_buffer_dp_views_SINGLE(self):
-    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
-    #     # Grad buffer views.
-    #     gbuf_items = []
-    #     for model_index, model in enumerate(self.models):
-    #         for dtype, gbuf in model._grad_buffers.items():
-    #             gbuf_items.append((model_index, dtype, gbuf.data))
-    #     return gbuf_items
-    # <<<
-    # >>>
-    # def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
-    #     # Iterate grad buffers & chunk.
-    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
-    #     chunk_view_items = []
-    #     for model_index, dtype, gbuf_views in gbuf_view_items:
-    #         # ** Sanity check. ** (should be unnecessary; see comment above)
-    #         view_numel = gbuf_views[0].nelement()
-    #         for view in gbuf_views:
-    #             assert view.nelement() == view_numel
-    #         # Compute chunk size (via savings factor).
-    #         chunk_numel_min = 131072
-    #         chunk_numel_max = view_numel
-    #         chunk_numel = int(
-    #             mem_savings_factor * chunk_numel_min
-    #             + (1 - mem_savings_factor) * chunk_numel_max
-    #         )
-    #         # Chunk views.
-    #         for start_index in range(0, view_numel, chunk_numel):
-    #             end_index = min(view_numel, start_index + chunk_numel)
-    #             chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
-    #             chunk_view_items.append((model_index, dtype, chunk_views))
-    #     return chunk_view_items
-    # <<<

     def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
@@ -474,44 +350,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
-        mem_savings_factor = args.distrib_opt_comm_mem_savings

         # Scale grad buffers by '1 / data_parallel_world_size'.
         for model in self.models:
             for dtype, gbuf in model._grad_buffers.items():
                 gbuf.data /= data_parallel_world_size

-        # Reduce scatter all grads.
-        # >>>
-        # gbuf_view_items = \
-        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        # for model_index, dtype, gbuf_views in gbuf_view_items:
-        #     torch.distributed.reduce_scatter(
-        #         gbuf_views[data_parallel_rank],
-        #         gbuf_views,
-        #         group = data_parallel_group,
-        #     )
-        # +++
+        # Reduce-scatter all grads.
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
         for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
-            # >>>
-            # pax(0, {
-            #     "gbuf_view" : gbuf_views[data_parallel_rank].shape,
-            #     "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
-            # })
-            # <<<
             torch.distributed._reduce_scatter_base(
                 gbuf_views[data_parallel_rank],
-                gbuf, # gbuf_view_items_SINGLE[index][2],
+                gbuf,
                 group = data_parallel_group,
             )
-            # torch.distributed.reduce_scatter(
-            #     gbuf_views[data_parallel_rank],
-            #     gbuf_views,
-            #     group = data_parallel_group,
-            # )
-        # <<<

         timers('backward-params-all-reduce').stop()

     def gather_model_params(self, args, timers):
@@ -520,32 +373,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
-        mem_savings_factor = args.distrib_opt_comm_mem_savings

         # All-gather updated main params.
         # - All grad buffer views are guaranteed to have the same num elements
         #   across all data parallel ranks, with grad buffer padding that is done
         #   in distributed.py. Thus, all sub-views will have consistent start/end
         #   indexes across data parallel ranks.
-        # >>>
-        # gbuf_view_items = \
-        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        # for model_index, dtype, gbuf_views in gbuf_view_items:
-        #     torch.distributed.all_gather(
-        #         gbuf_views,
-        #         gbuf_views[data_parallel_rank],
-        #         group = data_parallel_group,
-        #     )
-        # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
         for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             torch.distributed._all_gather_base(
-                gbuf, # gbuf_view_items_SINGLE[index][2],
+                gbuf,
                 gbuf_views[data_parallel_rank],
                 group = data_parallel_group,
             )
-        # <<<

         # Each model param now contains its updated values in its
         # '.main_grad' field.
...
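Both reduce_model_grads() and gather_model_params() above now issue a single collective per contiguous grad buffer instead of the removed chunked path. Below is a minimal sketch of that communication pattern, assuming a 1-D buffer already padded to a multiple of the data-parallel world size (as distributed.py guarantees for the real grad buffers); reduce_scatter_then_all_gather is a hypothetical standalone helper, and the underscore-prefixed primitives are the older private torch.distributed names used in the diff (newer PyTorch exposes them as reduce_scatter_tensor / all_gather_into_tensor).

```python
import torch
import torch.distributed as dist

def reduce_scatter_then_all_gather(gbuf: torch.Tensor, group=None):
    """Average a padded, contiguous grad buffer via reduce-scatter of equal
    per-rank shards, then (after the local optimizer update) all-gather the
    updated shards back into the full buffer."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    shard_size = gbuf.numel() // world_size
    views = [gbuf[r * shard_size:(r + 1) * shard_size]
             for r in range(world_size)]

    # Pre-scale so the reduce-scatter yields an average, as in the diff.
    gbuf /= world_size

    # Each rank receives the reduced values for its own shard only.
    dist._reduce_scatter_base(views[rank], gbuf, group=group)

    # ... the optimizer would update the local shard in place here ...

    # Broadcast every rank's updated shard back into the full buffer.
    dist._all_gather_base(gbuf, views[rank], group=group)
```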
@@ -27,23 +27,11 @@ from megatron import mpu
 from megatron import print_rank_0
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
-from megatron.utils import unwrap_model
-from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-# >>>
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-from lutil import pax
-get_clippy = lambda params : [ "%d, %d, %d ... %s" % (
-    p.grad is not None,
-    param_is_not_shared(p),
-    param_is_not_tensor_parallel_duplicate(p),
-    str(tuple(p.shape)),
-) for p in params ]
-# <<<
+from megatron.utils import unwrap_model
+from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

 def _zero_grad_group_helper(group, set_to_none):
@@ -112,12 +100,9 @@ class MegatronOptimizer(ABC):
                 params.append(param)
         return params

-    # >>>
     @abstractmethod
-    # def get_grads_for_norm(self):
-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         pass
-    # <<<

     def get_model_parallel_group(self):
         '''Default returned here, but the distributed optimizer overrides this.'''
@@ -126,7 +111,7 @@ class MegatronOptimizer(ABC):
     def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
-        grads_for_norm = self._get_main_grads_for_grad_norm()
+        grads_for_norm = self.get_main_grads_for_grad_norm()
         return clip_grad_norm_fp32(
             params, grads_for_norm, clip_grad,
             model_parallel_group=self.get_model_parallel_group())
@@ -544,17 +529,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # >>>
-        # model_params = [ p for m in self.models for p in m.parameters() ]
-        # optim_params = self.get_parameters()
-        # model_params.sort(key = lambda p : p.nelement(), reverse = True)
-        # optim_params.sort(key = lambda p : p.nelement(), reverse = True)
-        # # assert len(model_params) == len(optim_params
-        # pax(7, {
-        #     "model_params" : get_clippy(model_params),
-        #     "optim_params" : get_clippy(optim_params),
-        # })
-        # <<<

     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
@@ -570,7 +544,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
             _zero_grad_group_helper(group, set_to_none)

-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         # Filter parameters based on:
         #   - grad should not be none
...
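The rename from _get_main_grads_for_grad_norm() to get_main_grads_for_grad_norm() threads through the abstract base class, the Float16 optimizer, and the distributed optimizer above. A small sketch of that hook pattern follows, using illustrative stand-in class names rather than the real MegatronOptimizer hierarchy:

```python
from abc import ABC, abstractmethod

class OptimizerBase(ABC):
    """Stand-in for the base class: clip_grad_norm() depends only on the
    get_main_grads_for_grad_norm() hook renamed in this commit."""

    @abstractmethod
    def get_main_grads_for_grad_norm(self):
        ...

    def clip_grad_norm(self, clip_grad):
        grads_for_norm = self.get_main_grads_for_grad_norm()
        # The real code forwards these to clip_grad_norm_fp32(...).
        return grads_for_norm, clip_grad

class DistributedOptimizerSketch(OptimizerBase):
    """The distributed optimizer returns pre-built views into its sharded
    main grads (main_grad_views_for_grad_norm in the diff above)."""

    def __init__(self, grad_views):
        self.main_grad_views_for_grad_norm = grad_views

    def get_main_grads_for_grad_norm(self):
        return self.main_grad_views_for_grad_norm
```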