OpenDAS / Megatron-LM / Commits

Commit ae6cd9b5, authored Feb 18, 2022 by Lawrence McAfee

some cleanup

Parent: e5db0fda

Showing 1 changed file with 49 additions and 427 deletions.

megatron/optimizer/optimizer.py (+49, -427)
...
...
@@ -233,6 +233,10 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        return self.grad_scaler.scale

    def reload_model_params(self):
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):

        # Collect main grads.
...
...
@@ -269,6 +273,11 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # pax(0, {
        #     "params" : self.get_parameters(), # self.main_param_shards,
        #     "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
        # })

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:
...
...
@@ -284,6 +293,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        # If we found inf/nan, skip the update.
        if found_inf_flag:
            pax(0, {"found_inf_flag": found_inf_flag})
            return False, None, None

        # Clip the main gradients.
...
...
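Note (annotation, not part of the commit): the hunk above adds a temporary pax() trap before skipping the update when found_inf_flag is set. For context, a minimal sketch of how such a flag is typically derived from the unscaled main grads and agreed on across data-parallel ranks; the helper name and the main_grads, inv_scale, and group arguments are placeholders, not Megatron's exact API:

    import torch

    def check_for_inf(main_grads, inv_scale, group):
        # Unscale grads in place and look for any non-finite value.
        found_inf = torch.zeros(1, dtype=torch.float, device=main_grads[0].device)
        for grad in main_grads:
            grad.mul_(inv_scale)
            if not torch.isfinite(grad).all():
                found_inf.fill_(1.0)
        # If any rank saw inf/nan, every rank skips the step.
        torch.distributed.all_reduce(
            found_inf, op=torch.distributed.ReduceOp.MAX, group=group)
        return found_inf.item() > 0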
@@ -301,12 +311,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        self.optimizer.step()

        # >>>
        # from lutil import pax, tp
        # pax(0, {
        #     "optimizer / state" :
        #     { hash(k):tp(v) for k,v in self.optimizer.state.items() },
        #     "optimizer / state / len" : len(self.optimizer.state),
        #     "optimizer / state / 0" : list(self.optimizer.state.values())[0],
        #     # "optimizer / state" :
        #     # { hash(k):tp(v) for k,v in self.optimizer.state.items() },
        #     # "optimizer / state / len" : len(self.optimizer.state),
        #     # "optimizer / state / 0" : list(self.optimizer.state.values())[0],
        #     **{"optimizer / state / %s" % hash(k) : tp(v) for k, v in self.optimizer.state.items() },
        #     "params" : sum(
        #         s["exp_avg"].numel()
        #         for s in self.optimizer.state.values()
        #     ),
        # })
        # <<<
...
...
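Note (annotation, not part of the commit): the commented-out block above inspects self.optimizer.state after step(). The same idea against a plain torch.optim.Adam, counting the elements held in the exp_avg buffers (toy sizes, illustrative only):

    import torch

    params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    for p in params:
        p.grad = torch.randn_like(p)
    optimizer.step()

    # Per-param state is keyed by the parameter tensors themselves.
    num_state_elems = sum(s["exp_avg"].numel() for s in optimizer.state.values())
    print(len(optimizer.state), num_state_elems)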
@@ -536,8 +550,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
        timers('backward-embedding-all-reduce').stop()

    def gather_params(self):
        raise Exception("hi.")
        pass

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
...
...
@@ -621,10 +634,6 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
            overflow_buf = self._dummy_overflow_buf)

    def reload_model_params(self):
        self._copy_model_params_to_main_params()

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
...
...
@@ -669,13 +678,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
# >>>
import math
# from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
# from megatron import get_timers
# from megatron.model import DistributedDataParallel as LocalDDP
# from megatron.model import Float16Module
# from megatron.utils import unwrap_model

# class ShardIndex:
class Shard:
...
...
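Note (annotation, not part of the commit): the new Shard class introduced above (replacing the commented-out ShardIndex) is used later via .start, .end, .size and .normalize(). A minimal sketch of what such a range type could look like; the exact fields and methods in the commit may differ:

    class Shard:
        """Half-open index range [start, end) into a flat buffer."""

        def __init__(self, start, end):
            self.start = start
            self.end = end
            self.size = end - start

        def normalize(self, start=0):
            # Re-base the range at 'start' (e.g., world shard -> local shard).
            return Shard(start, start + self.size)

        def __str__(self):
            return "%d,%d [%d]" % (self.start, self.end, self.size)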
@@ -726,230 +729,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
})
# <<<
# def __init__(self, *_args):
# super().__init__(*_args)
# def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler):
# super().__init__(
# optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler)
# # >>>
# # self.test_reduce_scatter()
# # <<<
# # >>>
# args = get_args()
# # <<<
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
# self.data_parallel_world_size = mpu.get_data_parallel_world_size()
# # Total trainable param count.
# # self.total_param_size = sum(
# # p.numel()
# # for g in self.param_groups
# # for p in g["params"]
# # # if p .requires_grad ???
# # )
# # Model params: group sizes, group offset maps.
# # self.model_params = []
# # self.model_param_group_sizes = []
# # self.model_param_group_offset_maps = []
# self.model_param_groups = []
# for param_group in self.optimizer.param_groups:
# param_group_offset = 0
# param_group_offset_map = {}
# for param in param_group['params']:
# if not param.requires_grad:
# continue
# # self.model_params.append(param)
# param_group_offset_map[param] = {
# "start" : param_group_offset,
# "end" : param_group_offset + param.numel(),
# }
# param_group_offset += param.numel()
# # self.model_param_group_sizes.append(param_group_offset)
# # self.model_param_group_offset_maps.append(param_group_offset_map)
# self.model_param_groups.append({
# "size" : param_group_offset,
# "offset_map" : param_group_offset_map,
# })
# # pax(0, {
# # "model_params" : model_params,
# # "model_param_group_sizes" : model_param_group_sizes,
# # "model_param_group_offset_maps" : model_param_group_offset_maps,
# # })
# # Shard allocator.
# # ** torch.nn.Parameter ??
# # ** MemoryBuffer ??
# allocate_shard = lambda shard_size, dtype : torch.empty(
# (shard_size,),
# dtype = dtype,
# device = torch.cuda.current_device(),
# requires_grad = True)
# # Allocate shards.
# # (Also, collect world DP shard info.)
# # model_main_dtypes = set([ args.params_dtype, torch.float ])
# model_main_dtypes = set([ torch.float ]) # fp32 only, for now
# # self.world_shard_info_groups = [] # world_group_shard_infos ?
# # self.main_param_shard_groups = []
# self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
# for group_index, model_param_group in enumerate(self.model_param_groups):
# # Max world shard size.
# model_param_size = model_param_group["size"]
# max_world_shard_size = int(math.ceil(model_param_size /
# self.data_parallel_world_size))
# # DP world shard infos.
# # world_shard_infos = []
# for r in range(self.data_parallel_world_size):
# shard_start_index = r * max_world_shard_size
# shard_end_index = min(model_param_size,
# shard_start_index + max_world_shard_size)
# # world_shard_infos.append({
# self.world_shard_infos[r]["groups"].append({
# "start" : shard_start_index,
# "end" : shard_end_index,
# "size" : shard_end_index - shard_start_index,
# })
# # self.world_shard_info_groups.append(world_shard_infos)
# # self.world_shard_infos[group_index].append(world_shard_infos)
# # DP local rank's shard info.
# # local_shard_info = world_shard_infos[self.data_parallel_rank]
# local_shard_info = \
# self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
# local_shard_start_index = local_shard_info["start"]
# local_shard_end_index = local_shard_info["end"]
# local_shard_size = local_shard_info["size"]
# # Local shard's param 'slice' index map.
# local_shard_info["param_slice_index_map"] = {}
# for param, offset_dict in model_param_group["offset_map"].items():
# # param_start_index = offset_dict["start"]
# # param_end_index = offset_dict["end"]
# # param_shard_start_index = max(local_shard_start_index,
# # param_start_index)
# # param_shard_end_index = min(local_shard_end_index,
# # param_end_index)
# orig_start_index = offset_dict["start"]
# orig_end_index = offset_dict["end"]
# shard_start_index = max(
# 0,
# orig_start_index - local_shard_start_index)
# shard_end_index = min(
# local_shard_end_index,
# orig_end_index - local_shard_start_index)
# # if param_shard_end_index > param_shard_start_index:
# # # Indexes are relative to local shard start index.
# # # local_shard_info["param_index_map"][param] = {
# # # "param" : (
# # # param_shard_start_index,
# # # param_shard_end_index,
# # # ),
# # # "shard" : (
# # # param_shard_start_index - local_shard_start_index,
# # # param_shard_end_index - local_shard_start_index,
# # # ),
# # # }
# # local_shard_info["param_slice_index_map"][param] = {
# # "param_start" :
# # param_shard_start_index,
# # "shard_start" :
# # param_shard_start_index - local_shard_start_index,
# # "size":
# # param_shard_end_index - param_shard_start_index,
# # }
# if shard_end_index > shard_start_index:
# local_shard_info["param_slice_index_map"][param] = {
# "orig_start" : orig_start_index,
# "shard_start" : shard_start_index,
# "size" : shard_end_index - shard_start_index,
# }
# # pax(0, {
# # "local index" : "%d, %d" % (
# # local_shard_start_index,
# # local_shard_end_index,
# # ),
# # "param index" : "%s, %d" % (
# # param_start_index,
# # param_end_index,
# # ),
# # "param" : tp(param),
# # "shard_param_index_map" : shard_param_index_map,
# # "local_shard_info" : local_shard_info,
# # })
# # pax(2, {
# # "data_parallel_rank" : self.data_parallel_rank,
# # "local_shard_info" : local_shard_info,
# # "param_index_map " : [
# # (str(p.shape), i)
# # for p, i in local_shard_info["param_index_map"].items()
# # ],
# # })
# # Allocate shards.
# # (Non-fp32 shards are for convenience; e.g., intermediaries
# # between model params and main fp32 shard. Necessary???)
# # main_param_shards = {
# # ty : allocate_shard(local_shard_size, ty)
# # for ty in model_main_dtypes}
# main_param_shards = {}
# for dtype in model_main_dtypes:
# main_param = allocate_shard(local_shard_size, dtype)
# main_param.grad = allocate_shard(local_shard_size, dtype)
# # pax(0, {"main_param": main_param})
# main_param_shards[dtype] = main_param
# # self.main_param_shard_groups.append(main_param_shards)
# local_shard_info["data"] = main_param_shards
# # Update optimizer group.
# self.optimizer.param_groups[group_index]["params"] = \
# [ main_param_shards[torch.float] ]
# # pax(0, {
# # "param_groups" : self.optimizer.param_groups,
# # "params" : self.optimizer.param_groups[group_index]["params"],
# # })
# # Add world start/end indexes, for reduce/gather steps.
# offset = 0
# for r in self.world_shard_infos:
# r["start_index"] = offset
# offset += sum(g["size"] for g in r["groups"])
# r["end_index"] = offset
# # Leverage state_dict() and load_state_dict() to
# # recast preexisting per-param state tensors
# self.optimizer.load_state_dict(self.optimizer.state_dict())
# # >>>
# # pax(0, {
# # "world_shard_infos" : self.world_shard_infos,
# # **{
# # "world_shard_infos / %d" % i : r
# # for i, r in enumerate(self.world_shard_infos)
# # },
# # })
# # <<<
    @classmethod
    # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
    # def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
    # def get_model_gbuf_param_shard_index_map(cls,model,dtype,gbuf_world_index):
    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):

        # Param shard map.
...
...
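Note (annotation, not part of the commit): get_model_gbuf_param_shard_map (body mostly collapsed above) maps each parameter to the slice of it that falls inside this rank's grad-buffer shard; the commented-out constructor earlier in this hunk computes the same intersection. A rough sketch of that logic, reusing the hypothetical Shard type from the earlier note and assuming a param_index_map of per-param (start, end) offsets within the grad buffer:

    def build_param_shard_map(param_index_map, gbuf_world_shard):
        # param_index_map: {param: (start, end)} offsets within the grad buffer.
        # gbuf_world_shard: this rank's Shard of the grad buffer.
        param_shard_map = {}
        for param, (start, end) in param_index_map.items():
            # Clip the param's range to the rank's shard.
            shard_start = max(start, gbuf_world_shard.start)
            shard_end = min(end, gbuf_world_shard.end)
            if shard_end > shard_start:
                param_shard_map[param] = Shard(shard_start, shard_end)
        return param_shard_map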
@@ -980,9 +760,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return param_shard_map

    @classmethod
    # def get_ddp_gbuf_shard(cls, model, dtype):
    # def get_model_gbuf_shard(cls, model, dtype):
    # def get_model_gbuf_shard_index(cls, model, dtype):
    def get_model_gbuf_shard(cls, model, dtype):

        data_parallel_rank = mpu.get_data_parallel_rank()
...
...
@@ -1001,7 +778,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            gbuf_world_all_shards.append(gbuf_world_shard)
        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
        gbuf_local_shard = gbuf_world_shard.normalize()
        # gbuf_local_shard = Shard(0, gbuf_world_index.size)

        # Param shards.
        param_shard_map = cls.get_model_gbuf_param_shard_map(model,
...
...
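Note (annotation, not part of the commit): get_model_gbuf_shard (above) picks this rank's entry out of gbuf_world_all_shards and normalizes it to a zero-based local range. The commented-out constructor earlier in the file splits each grad buffer evenly across the data-parallel world; a sketch of that split, again using the hypothetical Shard type:

    import math

    def split_gbuf_across_ranks(gbuf_size, data_parallel_world_size):
        # Evenly sized shards; the last rank may get a short remainder.
        max_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))
        shards = []
        for rank in range(data_parallel_world_size):
            start = rank * max_shard_size
            end = min(gbuf_size, start + max_shard_size)
            shards.append(Shard(start, end))
        return shards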
@@ -1021,10 +797,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return data

    @classmethod
    # def get_ddp_gbuf_shards(cls, model):
    # def get_ddp_gbuf_shard_map(cls, model):
    # def get_model_gbuf_shard_map(cls, model):
    # def get_model_gbuf_shard_index_map(cls, model):
    def get_model_gbuf_shard_map(cls, model):

        # shard_index_map = {
...
...
@@ -1038,11 +810,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return shard_map

    @classmethod
    # def get_param_size_map(cls, model_gbuf_shards):
    # def get_param_model_gbuf_map(cls, model_gbuf_shards):
    def get_param_gbuf_map(cls, model_gbuf_shards):

        # param_size_map = {}
        param_gbuf_map = {}
        for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
...
...
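Note (annotation, not part of the commit): the nested loop above builds param_gbuf_map, which is later consumed as for param, (model_index, dtype) in self.param_gbuf_map.items(). A sketch of the mapping being built; the "param_map" key used to enumerate a shard's parameters is a guess, since that part of the hunk is collapsed:

    def build_param_gbuf_map(model_gbuf_shards):
        # Map each parameter to the (virtual model index, grad-buffer dtype)
        # that owns it, so later steps can locate the right buffer.
        param_gbuf_map = {}
        for model_index, model_gbuf_shard_map in enumerate(model_gbuf_shards):
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                for param in gbuf_shard_map["param_map"]:
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map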
@@ -1064,7 +833,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "param_gbuf_map" : param_gbuf_map,
# })
# return param_size_map
        return param_gbuf_map

    @classmethod
...
...
@@ -1120,8 +888,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        assert args.use_contiguous_buffers_in_local_ddp # already checked in args
# <<<
# pax(1, {"models": models})
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
...
...
@@ -1138,8 +904,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            self.optimizer.param_groups, self.model_gbuf_shards)
# pax(0, {"opt_group_shards": self.opt_group_shards})
# Allocate main param/grad shard.
# ** torch.nn.Parameter ??
# ** MemoryBuffer ??
...
...
@@ -1165,6 +929,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            # Update optimizer group.
            self.optimizer.param_groups[group_index]["params"] = [main_param]

        # Initialize main params.
        self._copy_model_params_to_main_params()

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
...
...
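Note (annotation, not part of the commit): the last line above re-loads the optimizer's own state_dict(). A toy example of why that round trip matters: torch optimizers save per-param state by position and, on load, re-key it onto the current params and cast the tensors to match them, so swapping a group's params for a new flat main parameter keeps any preexisting state usable. The parameter swap here is simplified relative to the commit:

    import torch

    w = torch.nn.Parameter(torch.randn(8))
    opt = torch.optim.Adam([w], lr=1e-3)
    w.grad = torch.randn_like(w)
    opt.step()                               # creates exp_avg / exp_avg_sq for w

    # Swap in a replacement parameter (e.g., a flat main shard) and move its state.
    main_w = torch.nn.Parameter(w.detach().clone())
    opt.param_groups[0]["params"] = [main_w]
    opt.state[main_w] = opt.state.pop(w)

    # Round-trip: state is packed by position, then re-keyed onto the current
    # params (and recast to their dtype/device) on load.
    opt.load_state_dict(opt.state_dict())
    print(main_w in opt.state)               # True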
@@ -1177,11 +944,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# })
# <<<
# def get_loss_scale(self):
# if self.grad_scaler is None:
# return self._scale_one
# return self.grad_scaler.scale
    def load_state_dict(self):
        raise Exception("hi.")

    def reload_model_params(self):
...
...
@@ -1189,21 +951,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def state_dict(self):
        raise Exception("hi.")
# def zero_grad(self, set_to_none=True):
# params = []
# for model_param_group in self.model_param_groups:
# params.extend(model_param_group["offset_map"].keys())
# for main_group in self.optimizer.param_groups:
# params.extend(main_group["params"])
# # _zero_grad_group_helper(params, set_to_none)
# _zero_grad_group_helper(params, set_to_none = False)
# # pax(0, {
# # "model_param_groups" : self.model_param_groups,
# # "params" : params,
# # })
    def zero_grad(self, set_to_none=True):
        model_params = []
...
...
@@ -1219,110 +966,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# pax(0, {"params": params})
# def reduce_gradients(self, model):
# # >>>
# # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# # <<<
# # >>>
# args = get_args()
# # timers = get_timers()
# # <<<
# # >>> [ temporary requirement ... and already checked in arguments.py ]
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Map param to virtual model.
# # ** ideally, this should happen once, during construction.
# param_model_map = {}
# for vmodel in model:
# for dtype, param_index_map in \
# vmodel._grad_buffer_param_index_map.items():
# for param in param_index_map:
# param_model_map[param] = {
# "dtype" : dtype,
# "model" : vmodel,
# }
# # pax(0, {
# # "param_model_map" : [
# # (str(tuple(p.shape)), m)
# # for p, m in param_model_map.items()
# # ],
# # })
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Copy model grads to main shard.
# local_shard_info_groups = [g[self.data_parallel_rank]
# for g in self.world_shard_info_groups]
# for group_index, local_shard_info in enumerate(local_shard_info_groups):
# # model_param_index_map =
# # shard_param_index_map = local_shard_info["param_index_map"]
# # main_index_map = local_shard_info["param_index_map"]
# main_slice_index_map = local_shard_info["param_slice_index_map"]
# for param, main_slice_indexes in main_slice_index_map.items():
# main_slice_orig_start_index = main_slice_indexes["orig_start"]
# main_slice_shard_start_index = main_slice_indexes["shard_start"]
# main_slice_size = main_slice_indexes["size"]
# dtype_model_dict = param_model_map[param]
# dtype = dtype_model_dict["dtype"]
# vmodel = dtype_model_dict["model"]
# model_grad_buffer = vmodel._grad_buffers[dtype].data
# model_grad_buffer_start_index = \
# vmodel._grad_buffer_param_index_map[dtype][param][0] + \
# main_slice_orig_start_index
# main_grad_view = local_shard_info["data"][torch.float].grad[
# main_slice_shard_start_index:
# main_slice_shard_start_index + main_slice_size
# ]
# model_grad_view = model_grad_buffer[
# model_grad_buffer_start_index:
# model_grad_buffer_start_index + main_slice_size
# ]
# main_grad_view.detach().copy_(model_grad_view)
# # pax(0, {
# # # "local_shard_info" : local_shard_info,
# # "main_slice_orig_start_index" : main_slice_orig_start_index,
# # "main_slice_shard_start_index" : main_slice_shard_start_index,
# # "main_slice_size" : main_slice_size,
# # "model_grad_buffer_start_index" :
# # model_grad_buffer_start_index,
# # "main_grad_view" : tp(main_grad_view),
# # "main_grad_view / detach" : tp(main_grad_view.detach()),
# # "model_grad_view" : tp(model_grad_view),
# # })
# # pax(0, {
# # "group_index" : group_index,
# # "local_shard_info" : local_shard_info,
# # "shard_param_index_map" : shard_param_index_map,
# # "param" : tp(param),
# # "shard_indexes" : shard_indexes,
# # "grad_buffer_indexes" : grad_buffer_indexes,
# # })
# pax(0, {
# # "world_shard_info_groups" : self.world_shard_info_groups,
# # **{"world_shard_info_groups / %d" % i : v
# # for i, v in enumerate(self.world_shard_info_groups)},
# # "local_shard_info_groups" : local_shard_info_groups,
# "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
# })
    def get_model_grad_buffer_dp_views(self):

        # >>>
        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp
        # <<<

        # Grad buffer views.
        gbuf_view_items = []
...
...
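Note (annotation, not part of the commit): get_model_grad_buffer_dp_views (above) builds one view per data-parallel rank over each model's contiguous grad buffer. A free-function paraphrase based on the commented-out reduce-scatter block earlier in this file (the _grad_buffers and "world_all" names come from that block; the exact structure in the commit is collapsed):

    def get_model_grad_buffer_dp_views(models, model_gbuf_shards):
        # For each virtual model and grad-buffer dtype, slice the contiguous
        # grad buffer into one view per data-parallel rank.
        gbuf_view_items = []
        for model_index, model in enumerate(models):
            for dtype, gbuf_shard in model_gbuf_shards[model_index].items():
                gbuf = model._grad_buffers[dtype]
                gbuf_views = [gbuf.data[shard.start:shard.end]
                              for shard in gbuf_shard["world_all"]]
                gbuf_view_items.append((model_index, dtype, gbuf_views))
        return gbuf_view_items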
@@ -1343,11 +993,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def reduce_gradients(self, model):

        # >>>
        args = get_args()
# timers = get_timers()
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sync word embedding params.
...
...
@@ -1360,36 +1005,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# assert args.use_contiguous_buffers_in_local_ddp
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype]
# gbuf_views = []
# for shard in world_shards:
# gbuf_views.append(gbuf.data[shard.start:shard.end])
# torch.distributed.reduce_scatter(
# gbuf_views[data_parallel_rank],
# gbuf_views,
# group = data_parallel_group,
# )
# # pax(0, {
# # "model_index" : model_index,
# # "model" : model,
# # "dtype" : str(dtype),
# # "gbuf_shard" : gbuf_shard,
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()
...
...
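Note (annotation, not part of the commit): the reduce-scatter step that follows the rank/group lookup above is collapsed in this hunk, but the commented-out version kept in the file shows the pattern; a sketch mirroring it, assuming gbuf_view_items comes from get_model_grad_buffer_dp_views():

    import torch

    def reduce_scatter_grad_buffers(gbuf_view_items, data_parallel_rank,
                                    data_parallel_group):
        # Each rank keeps only its shard of the summed gradients: the input
        # list covers the whole buffer, the output is this rank's slice of it.
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group=data_parallel_group,
            )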
@@ -1411,6 +1026,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        # All-gather updated main params.
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(gbuf_views,
...
...
@@ -1418,15 +1034,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                group = data_parallel_group,
            )
# for param, (model_index, dtype) in self.param_gbuf_map.items():
# gbuf = self.model_gbuf_shards[model_index][dtype]
# pax(0, {
# "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "gbuf" : gbuf,
# })
        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for param in self.param_gbuf_map:
            param.detach().copy_(param.main_grad)
# pax(0, {
...
...
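Note (annotation, not part of the commit): the two hunks above all-gather the updated shards and then copy each param's '.main_grad' view back into the model param. A free-function paraphrase; the middle argument of all_gather is collapsed in the diff and is assumed here to be this rank's own view:

    import torch

    def gather_model_params(gbuf_view_items, param_gbuf_map,
                            data_parallel_rank, data_parallel_group):
        # All-gather the updated shards so every rank holds the full buffer.
        for model_index, dtype, gbuf_views in gbuf_view_items:
            torch.distributed.all_gather(
                gbuf_views,
                gbuf_views[data_parallel_rank],
                group=data_parallel_group,
            )
        # The gathered values land in each param's '.main_grad' view of the
        # buffer; copy them back into the model params.
        for param in param_gbuf_map:
            param.detach().copy_(param.main_grad)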
@@ -1443,15 +1052,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ],
# })
# def step(self):
# raise Exception("step.")
    def _collect_main_grad_data_for_unscaling(self):
        # pax(1, {
        #     "main_param_shards" : self.main_param_shards,
        #     "grads" : [ p.grad.data for p in self.main_param_shards ],
        # })
        return [ p.grad.data for p in self.main_param_shards ]

    def _copy_model_grads_to_main_grads(self):
...
...
@@ -1465,11 +1066,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            assert main_shard.size == model_shard.size

            # Copy from DDP's contiguous buffer to main shard's grad.
            model_grad_tensor = \
                self.models[model_index]._grad_buffers[gbuf_dtype].data
            main_grad_tensor = \
                self.main_param_shards[group_index].grad

            # Copy sub-range within tensor.
            model_grad_view = \
                model_grad_tensor[model_shard.start:model_shard.end]
            main_grad_view = \
...
...
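Note (annotation, not part of the commit): the hunk above copies a sub-range of the DDP contiguous grad buffer into the corresponding sub-range of the main shard's grad. The tail of the copy is collapsed; a sketch of the whole per-shard copy, with the final .detach().copy_() taken from the commented-out version earlier in the file:

    def copy_shard_grad(model_grad_tensor, main_grad_tensor,
                        model_shard, main_shard):
        # Both shards describe the same number of elements, at different
        # offsets: one in the DDP grad buffer, one in the main grad shard.
        assert main_shard.size == model_shard.size
        model_grad_view = model_grad_tensor[model_shard.start:model_shard.end]
        main_grad_view = main_grad_tensor[main_shard.start:main_shard.end]
        main_grad_view.detach().copy_(model_grad_view)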
@@ -1491,10 +1094,20 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "main_shard" : str(main_shard),
# })
# >>>
# pax(0, {
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
        for param in self.main_param_shards:
            grad = param.grad
            is_nan = torch.any(torch.isnan(grad)).item()
            if is_nan:
                pax(0, {
                    "grad" : tp(grad),
                    "is_nan" : is_nan,
                })
        # <<<

    def _copy_main_params_to_model_params(self):
...
...
@@ -1538,6 +1151,15 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "model_gbuf_shards" : self.model_gbuf_shards,
# "opt_group_shards" : self.opt_group_shards,
# })
# >>>
        for param in self.param_gbuf_map:
            is_nan = torch.any(torch.isnan(param)).item()
            if is_nan:
                pax(0, {
                    "param" : tp(param),
                    "is_nan" : is_nan,
                })
# <<<
# <<<
...
...