Commit 5706ba42 authored by Lawrence McAfee

bit more progress

parent f48e1f29
@@ -121,11 +121,21 @@ class MegatronOptimizer(ABC):
return self.get_loss_scale() * loss
@abstractmethod
def reduce_gradients(self):
pass
@abstractmethod
def step(self):
pass
@abstractmethod
def gather_params(self):
pass
@abstractmethod
def reload_model_params(self):
"""Refreshes any internal state from the current model parameters.
@@ -170,36 +180,13 @@ class MegatronOptimizer(ABC):
class BaseFloat16Optimizer(MegatronOptimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0.
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters' gradients are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example, for bfloat16 we want
to do gradient accumulation and all-reduces in float32,
and as a result we store those gradients in `main_grad`.
Note that the main grad is not necessarily in float32.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also, for `bf16 = False`, we
always require a grad scaler.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp)
@@ -228,6 +215,48 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if self.grad_scaler is None:
self._scale_one = torch.cuda.FloatTensor([1.0])
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
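For reference, a small sketch of how the scale returned above is consumed (it mirrors the get_loss_scale() * loss line earlier in this diff); the numbers are illustrative:

import torch

# With bf16 and no grad scaler, the base class falls back to a constant 1.0,
# so scaling the loss is a no-op.
scale_one = torch.cuda.FloatTensor([1.0])      # what _scale_one holds
loss = torch.tensor(2.5, device="cuda")
scaled_loss = scale_one * loss                 # still 2.5

# With an fp16 grad scaler, grad_scaler.scale is the (possibly dynamic) loss
# scale; gradients must later be unscaled before clipping and the step.
loss_scale = torch.cuda.FloatTensor([4096.0])
scaled_loss = loss_scale * loss                # 10240.0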
# class Float16OptimizerWithFloat16Params(MegatronOptimizer):
class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradients with this global L2 norm. Note
that clipping is ignored if clip_grad == 0.
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters' gradients are stored in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a contiguous buffer
holding the gradients. For example, for bfloat16 we want
to do gradient accumulation and all-reduces in float32,
and as a result we store those gradients in `main_grad`.
Note that the main grad is not necessarily in float32.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constant gradient scaler. Also, for `bf16 = False`, we
always require a grad scaler.
"""
def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler)
# ======================
# main parameter stuff
# ======================
@@ -319,29 +348,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
# recast preexisting per-param state tensors
self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# from lutil import pax
# pax(0, {
# # "float16_groups / len" : [ len(g) for g in self.float16_groups ],
# # "fp32_from_float16_groups / len" :
# # [ len(g) for g in self.fp32_from_float16_groups ],
# # "float16_groups / 0" : self.float16_groups[0],
# # "float16_groups / 1" : self.float16_groups[1],
# # "fp32_from_float16_groups / 0" : self.fp32_from_float16_groups[0],
# # "fp32_from_float16_groups / 1" : self.fp32_from_float16_groups[1],
# # "fp32_from_float32_groups" : self.fp32_from_fp32_groups,
# "optimizer" : self.optimizer,
# # "optimizer / sd" : self.optimizer.state_dict(),
# # "optimizer / state" : self.optimizer.state_dict()["state"],
# # "optimizer / pg" : self.optimizer.state_dict()["param_groups"],
# # "optimizer / pg / 0" : self.optimizer.state_dict()["param_groups"][0],
# # "optimizer / pg / 1" : self.optimizer.state_dict()["param_groups"][1],
# "optimizer -> pg" : optimizer.param_groups,
# "optimizer -> pg / 0" : optimizer.param_groups[0]["params"],
# "optimizer -> pg / 1" : optimizer.param_groups[1]["params"],
# })
# <<<
def zero_grad(self, set_to_none=True):
"""We only need to zero the model related parameters, i.e.,
@@ -357,12 +363,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
_zero_grad_group_helper(group, set_to_none)
def get_loss_scale(self):
if self.grad_scaler is None:
return self._scale_one
return self.grad_scaler.scale
# >>>
def reduce_gradients(self, model):
@@ -658,7 +658,8 @@ from lutil import pax, tp
# <<<
# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
# class Float16DistributedOptimizer(MegatronOptimizer):
class Float16DistributedOptimizer(BaseFloat16Optimizer):
# >>>
@classmethod
@@ -702,7 +703,8 @@ class Float16DistributedOptimizer(MegatronOptimizer):
super().__init__(
optimizer, clip_grad, log_num_zeros_in_grad,
params_have_main_grad, use_contiguous_buffers_in_local_ddp,
bf16, grad_scaler)
# >>>
# self.test_reduce_scatter()
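The commented-out self.test_reduce_scatter() call above suggests a sanity check of the reduce-scatter collective this optimizer relies on. A rough guess at what such a test might look like (the function and its checks are assumptions, not code from this commit):

import torch
import torch.distributed as dist

def test_reduce_scatter(shard_size, group=None):
    # Each rank contributes world_size chunks; rank r receives the
    # elementwise sum of every rank's r-th chunk.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    device = torch.cuda.current_device()
    input_list = [torch.full((shard_size,), float(rank), device=device)
                  for _ in range(world_size)]
    output = torch.empty((shard_size,), device=device)
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM, group=group)
    # Every rank filled its chunks with its own rank id, so each output slot
    # should hold 0 + 1 + ... + (world_size - 1).
    expected = float(sum(range(world_size)))
    assert torch.all(output == expected)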
@@ -759,34 +761,41 @@ class Float16DistributedOptimizer(MegatronOptimizer):
allocate_shard = lambda shard_size, dtype : torch.empty(
(shard_size,),
dtype = dtype,
device = torch.cuda.current_device(),
requires_grad = True)
# return torch.nn.Parameter ?
# allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
# Allocate shards.
# (Also, collect world DP shard info.)
model_main_dtypes = set([ args.params_dtype, torch.float ])
self.world_shard_info_groups = [] # world_group_shard_infos ?
self.main_param_shard_groups = []
for group_index, model_param_group in enumerate(self.model_param_groups):
model_param_size = model_param_group["size"]
max_world_shard_size = int(math.ceil(model_param_size /
self.data_parallel_world_size))
# DP world shard infos.
world_shard_infos = []
for r in range(self.data_parallel_world_size):
shard_start_index = r * max_world_shard_size
shard_end_index = min(model_param_size,
shard_start_index + max_world_shard_size)
world_shard_infos.append({
"start" : shard_start_index,
"end" : shard_end_index,
"size" : shard_end_index - shard_start_index,
})
self.world_shard_info_groups.append(world_shard_infos)
# pax(0, {"world_shard_infos": world_shard_infos})
# Allocate shards.
# (Non-fp32 shards are for convenience; e.g., intermediaries
# between model params and main fp32 shard. Necessary???)
local_shard_size = world_shard_infos[self.data_parallel_rank]["size"]
# # self.main_param_shard = allocate_shard(torch.float)
# # self.main_grad_shard = allocate_shard(torch.float)
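To make the partitioning in this hunk concrete: each parameter group's flattened size is split into data_parallel_world_size contiguous shards via ceil division, and the last shard is clamped to the group's total size. A standalone sketch of the same arithmetic with made-up sizes (the helper name is illustrative):

import math

def build_world_shard_infos(model_param_size, dp_world_size):
    # Mirrors the loop above: rank r owns [start, end) of the flattened group.
    max_world_shard_size = int(math.ceil(model_param_size / dp_world_size))
    infos = []
    for r in range(dp_world_size):
        start = r * max_world_shard_size
        end = min(model_param_size, start + max_world_shard_size)
        infos.append({"start": start, "end": end, "size": end - start})
    return infos

# Example: a group of 10 elements over 4 data-parallel ranks
# -> shards of size 3, 3, 3, 1.
print(build_world_shard_infos(10, 4))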
@@ -795,29 +804,50 @@ class Float16DistributedOptimizer(MegatronOptimizer):
# self.adam_m_shard = allocate_shard(torch.float)
# self.adam_v_shard = allocate_shard(torch.float)
main_param_shards = {
ty : allocate_shard(local_shard_size, ty)
for ty in model_main_dtypes}
self.main_param_shard_groups.append(main_param_shards)
# Update optimizer group.
self.optimizer.param_groups[group_index]["params"] = \
[ main_param_shards[torch.float] ]
# >>>
# pax(0, {
# "param_groups" : self.optimizer.param_groups,
# "params" : self.optimizer.param_groups[group_index]["params"],
# })
# <<<
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self.optimizer.load_state_dict(self.optimizer.state_dict())
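The two steps above are the core of the sharded setup: each optimizer param group is repointed at the rank-local fp32 shard, and the state_dict()/load_state_dict() round trip re-keys any preexisting per-param optimizer state to those new tensors. A toy sketch of the same pattern with a plain torch.optim.Adam (sizes and names are illustrative):

import torch

# One bf16 model parameter and a rank-local fp32 "main" shard standing in
# for the buffers built above.
model_param = torch.nn.Parameter(torch.randn(16, dtype=torch.bfloat16))
optimizer = torch.optim.Adam([model_param], lr=1e-3)

main_shard = torch.empty(4, dtype=torch.float, requires_grad=True)

# Point the param group at the shard, then round-trip the state dict so the
# optimizer's internal state is keyed to the new tensor.
optimizer.param_groups[0]["params"] = [main_shard]
optimizer.load_state_dict(optimizer.state_dict())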
# def get_loss_scale(self):
# if self.grad_scaler is None:
# return self._scale_one
# return self.grad_scaler.scale
def load_state_dict(self):
raise Exception("hi.")
def reload_model_params(self):
raise Exception("hi.")
def state_dict(self):
raise Exception("hi.")
def zero_grad(self, set_to_none=True):
params = []
for model_param_group in self.model_param_groups:
params.extend(model_param_group["offset_map"].keys())
for main_group in self.optimizer.param_groups:
params.extend(main_group["params"])
_zero_grad_group_helper(params, set_to_none)
# pax(0, {
# "model_param_groups" : self.model_param_groups,
# "params" : params,
# })
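_zero_grad_group_helper is defined elsewhere in this file; for context, a sketch of what such a helper conventionally does in Megatron's optimizer code (drop the .grad tensors entirely, or detach and zero them in place):

def _zero_grad_group_helper(group, set_to_none):
    # Zero the gradient of every parameter in `group`.
    for param in group:
        if param.grad is not None:
            if set_to_none:
                # Drop the tensor; the next backward pass re-allocates it.
                param.grad = None
            else:
                # Keep the buffer, but detach it from any graph and zero it.
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()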
def reduce_gradients(self, model):
@@ -880,10 +910,16 @@ class Float16DistributedOptimizer(MegatronOptimizer):
})
# <<<
def step(self):
raise Exception("step.")
def gather_params(self):
raise Exception("gather params.")
# <<<
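gather_params is still a stub in this commit; in a ZeRO-style distributed optimizer, the counterpart of reduce-scattering gradients is all-gathering the updated parameter shards back into the full model parameters after step. A rough, hypothetical sketch of that direction (the names, the flat parameter buffer, and the equal-sized/padded shards are assumptions, not this commit's implementation):

import torch
import torch.distributed as dist

def gather_params_sketch(local_main_shard, full_param_flat, dp_group=None):
    # Cast the rank-local fp32 shard back to the model dtype, all-gather the
    # equal-sized shards from every data-parallel rank, and copy the result
    # into the flat buffer backing the model parameters.
    world_size = dist.get_world_size(dp_group)
    local_model_shard = local_main_shard.detach().to(full_param_flat.dtype)
    gathered = [torch.empty_like(local_model_shard) for _ in range(world_size)]
    dist.all_gather(gathered, local_model_shard, group=dp_group)
    flat = torch.cat(gathered)[:full_param_flat.numel()]
    full_param_flat.copy_(flat)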