OpenDAS / Megatron-LM

Commit 772a4a2d
Authored Mar 09, 2022 by Lawrence McAfee
Parent: 6e888151

encapsulated 'get_model_parallel_group()'; renamed BaseFloat16Optimizer -> MixedPrecisionOptimizer
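The shape of the change the message describes can be sketched as follows. This is a minimal illustration, not the Megatron code itself: the class names are shortened, the constructor argument `model_parallel_group` and the helper `_reduce_found_inf` are hypothetical stand-ins, and only the found-inf reduction from `_unscale_main_grads_and_check_for_nan` is shown.

from abc import ABC, abstractmethod
import torch

class MixedPrecisionOptimizer(ABC):
    """Base class: makes no assumption about the parallel layout."""

    @abstractmethod
    def get_model_parallel_group(self):
        """Return the process group the found-inf flag must be reduced over,
        or None to use the default (world) group."""

    def _reduce_found_inf(self, found_inf):
        # The base class no longer reaches into mpu directly; it asks the
        # subclass which group to reduce over.
        torch.distributed.all_reduce(found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())

class Float16Optimizer(MixedPrecisionOptimizer):
    """Replicated params: reduce over the model-parallel group."""

    def __init__(self, model_parallel_group):
        self._group = model_parallel_group  # e.g. mpu.get_model_parallel_group()

    def get_model_parallel_group(self):
        return self._group

class DistributedOptimizer(MixedPrecisionOptimizer):
    """Sharded params: no param replication, so reduce over the default group."""

    def get_model_parallel_group(self):
        return None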
Showing 2 changed files with 59 additions and 43 deletions:

  megatron/optimizer/distrib_optimizer.py   +45  -39
  megatron/optimizer/optimizer.py           +14  -4
megatron/optimizer/distrib_optimizer.py
...
...
@@ -321,45 +321,51 @@ class DistributedOptimizer(MegatronOptimizer):
        # Initialize main params.
        self._copy_model_params_to_main_params()

    @staticmethod
    def has_nan_debug(tensors):
        if isinstance(tensors, torch.Tensor):
            tensors = [ tensors ]
        assert isinstance(tensors, list)
        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
        has_nan = any(has_nans)
        return has_nan

    def get_local_model_param_views(self):
        '''** FOR DEBUGGING. **'''
        model_param_views = []
        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
            for param, opt_shard in opt_group_shard["param_map"].items():
                model_index, dtype = self.param_gbuf_map[param]
                gbuf_shard_map = \
                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
                model_param_shard = gbuf_shard_map["param"]
                model_param_views.append(
                    param.view(-1)[model_param_shard.start:model_param_shard.end])
        return model_param_views

    def get_local_model_grad_views(self):
        '''** FOR DEBUGGING. **'''
        model_grad_views = []
        for group_index, opt_group_shard in enumerate(self.opt_group_shards):
            for param, opt_shard in opt_group_shard["param_map"].items():
                model_index, dtype = self.param_gbuf_map[param]
                gbuf = self.models[model_index]._grad_buffers[dtype].data
                gbuf_shard_map = \
                    self.model_gbuf_shards[model_index][dtype]["param_map"][param]
                gbuf_world_shard = gbuf_shard_map["gbuf_world"]
                model_grad_views.append(
                    gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
        return model_grad_views

    def get_world_model_params(self):
        '''** FOR DEBUGGING. **'''
        return [ p for m in self.models for p in m.parameters() ]

    def get_world_model_grads(self):
        '''** FOR DEBUGGING. **'''
        return [ p.main_grad for p in self.get_world_model_params() ]

    def get_model_parallel_group(self):
        # >>>
        # i.e., no param replication across this group
        # <<<
        return None

    # @staticmethod
    # def has_nan_debug(tensors):
    #     if isinstance(tensors, torch.Tensor):
    #         tensors = [ tensors ]
    #     assert isinstance(tensors, list)
    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
    #     has_nan = any(has_nans)
    #     return has_nan
    # def get_local_model_param_views(self):
    #     '''** FOR DEBUGGING. **'''
    #     model_param_views = []
    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
    #         for param, opt_shard in opt_group_shard["param_map"].items():
    #             model_index, dtype = self.param_gbuf_map[param]
    #             gbuf_shard_map = \
    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
    #             model_param_shard = gbuf_shard_map["param"]
    #             model_param_views.append(
    #                 param.view(-1)[model_param_shard.start:model_param_shard.end])
    #     return model_param_views
    # def get_local_model_grad_views(self):
    #     '''** FOR DEBUGGING. **'''
    #     model_grad_views = []
    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
    #         for param, opt_shard in opt_group_shard["param_map"].items():
    #             model_index, dtype = self.param_gbuf_map[param]
    #             gbuf = self.models[model_index]._grad_buffers[dtype].data
    #             gbuf_shard_map = \
    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
    #             gbuf_world_shard = gbuf_shard_map["gbuf_world"]
    #             model_grad_views.append(
    #                 gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
    #     return model_grad_views
    # def get_world_model_params(self):
    #     '''** FOR DEBUGGING. **'''
    #     return [ p for m in self.models for p in m.parameters() ]
    # def get_world_model_grads(self):
    #     '''** FOR DEBUGGING. **'''
    #     return [ p.main_grad for p in self.get_world_model_params() ]

    def get_main_params(self):
        return [ g["params"][0] for g in self.optimizer.param_groups ]
...
...
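As a side note on the debug views above, the slicing pattern `param.view(-1)[shard.start:shard.end]` returns a view into the flattened parameter, so the local shard can be inspected (or modified) in place. A small self-contained sketch, where the `Range` class and the example tensor are hypothetical stand-ins for the shard ranges stored in the gbuf-shard maps:

import torch

# Hypothetical stand-in for a shard range (start/end offsets into a
# flattened buffer); the real code keeps these in the *_gbuf_shards maps.
class Range:
    def __init__(self, start, end):
        self.start, self.end = start, end

param = torch.arange(12.0).reshape(3, 4)   # a model parameter
shard = Range(4, 8)                        # this rank owns elements [4, 8)

# Same slicing pattern as get_local_model_param_views: flatten, then take the
# contiguous slice this rank's optimizer shard covers. The result is a view,
# so writes through it update the original parameter.
local_view = param.view(-1)[shard.start:shard.end]
print(local_view)   # tensor([4., 5., 6., 7.])
local_view.zero_()
print(param[1])     # tensor([0., 0., 0., 0.])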
megatron/optimizer/optimizer.py
...
...
@@ -182,7 +182,8 @@ class MegatronOptimizer(ABC):
    param_groups = property(_get_param_groups, _set_param_groups)


class BaseFloat16Optimizer(MegatronOptimizer):
# class BaseFloat16Optimizer(MegatronOptimizer):
class MixedPrecisionOptimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad,
                 use_contiguous_buffers_in_local_ddp,
...
...
@@ -222,6 +223,10 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        self._scale_one = torch.cuda.FloatTensor([1.0])

    @abstractmethod
    def get_model_parallel_group(self, state_dict):
        pass

    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
...
...
@@ -232,7 +237,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
    def _unscale_main_grads_and_check_for_nan(self, group):

        # Collect main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()
...
...
@@ -246,13 +251,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
            main_grads, self.found_inf, self.grad_scaler.inv_scale)

        # Update across all model parallel instances.
        # >>>
        if args.use_
        # >>>
        # torch.distributed.all_reduce(self.found_inf,
        #                              op=torch.distributed.ReduceOp.MAX,
        #                              group=mpu.get_model_parallel_group())
        # +++
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX)
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())
        # <<<

        # Check for nan.
...
...
@@ -517,6 +523,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
    # <<<

    def get_model_parallel_group(self):
        return mpu.get_model_parallel_group()

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
...
...
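One practical consequence of the encapsulation: `DistributedOptimizer.get_model_parallel_group()` returns None, and `torch.distributed.all_reduce(..., group=None)` reduces over the default (world) process group. The sketch below shows that found-inf reduction in isolation; the single-process gloo setup is an assumption made so the example runs standalone and is not part of the commit.

import os
import torch
import torch.distributed as dist

# Assumed local setup (world_size=1, gloo) so the script runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# found_inf is 1.0 on any rank that saw a non-finite grad, 0.0 otherwise.
found_inf = torch.tensor([1.0])

# group=None -> default group: every rank learns whether *any* rank overflowed.
dist.all_reduce(found_inf, op=dist.ReduceOp.MAX, group=None)
print(found_inf.item())   # 1.0 -> skip the optimizer step this iteration

dist.destroy_process_group()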