OpenDAS / Megatron-LM

Commit 23f9238d, authored Feb 22, 2022 by Lawrence McAfee

    set main_param_shard==None if local group_size==0 (e.g., word embedding params)

Parent: ac5ef637
Showing 1 changed file with 63 additions and 45 deletions.
megatron/optimizer/optimizer.py  (+63, -45)
@@ -293,7 +293,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # If we found inf/nan, skip the update.
         if found_inf_flag:
-            pax(0, {"found_inf_flag": found_inf_flag})
+            # pax(0, {"found_inf_flag": found_inf_flag})
             return False, None, None

         # Clip the main gradients.
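Note: the guard above is the usual mixed-precision "skip the step on overflow" path; the pax call being commented out is just the author's debug hook. A minimal sketch of the pattern, with hypothetical helper names (found_inf_or_nan, maybe_step) rather than Megatron's real ones:

    import torch

    def found_inf_or_nan(main_grads):
        # True if any unscaled main gradient contains Inf or NaN.
        return any((~torch.isfinite(g)).any().item() for g in main_grads)

    def maybe_step(optimizer, main_grads):
        # Mirrors the guard above: on overflow, skip the update entirely and
        # report the failed step as (False, None, None).
        if found_inf_or_nan(main_grads):
            return False, None, None
        optimizer.step()
        return True, None, None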
@@ -758,8 +758,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             "param" : sub_param_shard,
         }
         # >>>
-        if param_world_start < gbuf_world_shard.start:
-            raise Exception("hi.")
+        # if param_world_start < gbuf_world_shard.start:
+        #     pax({"param shards": param_shard_map[param]})
         # <<<

         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
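The range check being commented out above compares a parameter's position in the world grad buffer against this rank's buffer shard. A sketch of that bookkeeping, under assumed names taken from the hunk (param_world_start, gbuf_world_shard with .start/.end):

    from collections import namedtuple

    Range = namedtuple("Range", ["start", "end"])

    def param_range_on_this_rank(param_world_start, param_world_end, gbuf_world_shard):
        # Clip the parameter's world-buffer range to this rank's shard; a param
        # may start before the shard, which is what the removed check probed.
        start = max(param_world_start, gbuf_world_shard.start)
        end = min(param_world_end, gbuf_world_shard.end)
        return Range(start, end) if start < end else None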
@@ -865,13 +865,23 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             group_shard["size"] += param_size
             group_shard["param_map"][param] = param_group_shard

+        # raise Exception("hi.")
         # >>>
-        # pax(0, {"param_group_map": [
-        #     (g, str(p.shape))
-        #     for p, g in param_group_map.items()
-        # ]})
-        # pax(0, {"group_shards": group_shards})
+        # if torch.distributed.get_rank() == 1:
+        #     print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
+        #         torch.distributed.get_rank(),
+        #         group_index,
+        #         param_size,
+        #         str(tuple(param.shape)),
+        #     ))
         # <<<
+
+        # pax(1, {
+        #     "param_group_map": [
+        #         (g, str(p.shape))
+        #         for p, g in param_group_map.items()
+        #     ],
+        #     "group_shards" : group_shards,
+        # })

         return group_shards
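For reference, the bookkeeping in this hunk: each optimizer param group accumulates a flat group shard whose "size" grows by each local parameter's shard size and whose "param_map" records the offset range that parameter occupies inside the group. A standalone sketch with hypothetical inputs:

    from collections import namedtuple

    Range = namedtuple("Range", ["start", "end"])

    def build_group_shard(local_param_sizes):
        # local_param_sizes: ordered {param_name: number of elements on this rank}
        group_shard = {"size": 0, "param_map": {}}
        for param, param_size in local_param_sizes.items():
            param_group_shard = Range(group_shard["size"],
                                      group_shard["size"] + param_size)
            group_shard["size"] += param_size
            group_shard["param_map"][param] = param_group_shard
        return group_shard

    # build_group_shard({"w1": 12, "w2": 0, "w3": 5})
    # -> size 17; "w2" maps to the empty range [12, 12)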
@@ -913,7 +923,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 dtype = dtype,
                 device = torch.cuda.current_device(),
                 requires_grad = True)

         self.main_param_shards = []
         for group_index, group_shard in enumerate(self.opt_group_shards):
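The allocate_shard(...) calls in the next hunk presumably wrap an allocation like the one whose keyword arguments are visible above; a plausible sketch (the real helper is defined elsewhere in optimizer.py):

    import torch

    def allocate_shard(shard_size, dtype):
        # Empty leaf tensor on the current CUDA device, sized for this rank's shard.
        return torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)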
@@ -922,14 +932,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             # ** todo: for dtype in model_main_dtypes ........ **

             # Allocate shard.
-            main_param = allocate_shard(group_size, torch.float)
-            main_param.grad = allocate_shard(group_size, torch.float)
+            if group_size == 0:
+                main_param = None
+            else:
+                main_param = allocate_shard(group_size, torch.float)
+                main_param.grad = allocate_shard(group_size, torch.float)
+                mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
             self.main_param_shards.append(main_param)
-            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)

             # Update optimizer group.
             self.optimizer.param_groups[group_index]["params"] = [main_param]

+        # >>>
+        pax(0, {
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
+        })
+        # <<<
+
         # Initialize main params.
         self._copy_model_params_to_main_params()
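This hunk is the change named in the commit message: when a rank owns no elements of a param group (group_size == 0, e.g. the word embedding parameters live entirely in other ranks' shards), no main-param shard is allocated and None is appended as a placeholder, so downstream code presumably has to skip None entries. A simplified sketch of just that branch:

    import torch

    def build_main_param_shard(group_size, allocate_shard):
        # group_size == 0: this rank holds nothing for the group -> no shard.
        if group_size == 0:
            return None
        main_param = allocate_shard(group_size, torch.float)
        main_param.grad = allocate_shard(group_size, torch.float)
        return main_param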
@@ -937,13 +958,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())

-        # >>>
-        # pax(0, {
-        #     "model_gbuf_shards" : self.model_gbuf_shards,
-        #     "opt_group_shards" : self.opt_group_shards,
-        #     "main_param_shards" : self.main_param_shards,
-        # })
-        # <<<

     def load_state_dict(self):
         raise Exception("hi.")
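The load_state_dict(state_dict()) round-trip kept above is a standard trick for making torch.optim recast its preexisting per-param state against whatever tensors now sit in param_groups. A standalone illustration with torch.optim.Adam (not the Megatron classes):

    import torch

    p = torch.zeros(4, requires_grad=True)
    opt = torch.optim.Adam([p], lr=1e-3)
    p.grad = torch.ones_like(p)
    opt.step()  # creates per-param state (exp_avg, exp_avg_sq, ...)

    # Swap in a replacement param, as the distributed optimizer does with its
    # main param shards, then round-trip: load_state_dict casts the existing
    # floating-point state tensors to the new param's dtype/device.
    main_p = p.detach().double().requires_grad_()
    opt.param_groups[0]["params"] = [main_p]
    opt.load_state_dict(opt.state_dict())
    # opt.state[main_p]["exp_avg"].dtype is now torch.float64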
@@ -1071,22 +1085,26 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # Copy shard data.
                 main_view = main_param[main_shard.start:main_shard.end]
-                model_view = model_param[model_shard.start:model_shard.end].view(-1)
+                model_view = model_param.view(-1)[model_shard.start:model_shard.end]
+                # try:
                 main_view.detach().copy_(model_view)
-                # pax(0, {
+                # except:
+                #     pax({
                 #     "main_param" : tp(main_param),
                 #     "model_param" : tp(model_param),
                 #     "main_view" : tp(main_view),
                 #     "model_view" : tp(model_view),
                 #     "main_shard" : str(main_shard),
                 #     "model_shard" : str(model_shard),
                 # })

-        # pax(0, {
-        #     "opt_group_shards" : self.opt_group_shards,
-        #     "main_param_shards" : self.main_param_shards,
-        # })
+        pax(1, {
+            **{
+                "opt_group_shards / %d" % i : s
+                for i, s in enumerate(self.opt_group_shards)
+            },
+            "main_param_shards" : self.main_param_shards,
+        })

     def _copy_model_grads_to_main_grads(self):
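The model_view change is the substantive fix in this hunk: the shard offsets are flat element offsets, so the parameter has to be flattened before slicing, not sliced (by rows) and then flattened. A small illustration with a hypothetical 3x4 parameter:

    import torch

    model_param = torch.arange(12).reshape(3, 4)
    start, end = 2, 7  # flat element range owned by this rank (hypothetical)

    old_view = model_param[start:end].view(-1)  # slices rows -> 4 elements [8, 9, 10, 11]
    new_view = model_param.view(-1)[start:end]  # flat elements 2..6 -> [2, 3, 4, 5, 6]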
@@ -1128,14 +1146,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "model_gbuf_shards" : self.model_gbuf_shards,
         #     "opt_group_shards" : self.opt_group_shards,
         # })
-        for param in self.main_param_shards:
-            grad = param.grad
-            is_nan = torch.any(torch.isnan(grad)).item()
-            if is_nan:
-                pax(0, {
-                    "grad" : tp(grad),
-                    "is_nan" : is_nan,
-                })
+        # for param in self.main_param_shards:
+        #     grad = param.grad
+        #     is_nan = torch.any(torch.isnan(grad)).item()
+        #     if is_nan:
+        #         pax(0, {
+        #             "grad" : tp(grad),
+        #             "is_nan" : is_nan,
+        #         })
         # <<<
@@ -1183,7 +1201,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         for param in self.param_gbuf_map:
             is_nan = torch.any(torch.isnan(param)).item()
             if is_nan:
-                pax(0, {
+                pax({
                     "param" : tp(param),
                     "is_nan" : is_nan,
                 })
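The last two hunks only trim the author's NaN-scan debug hooks (pax is the author's debug helper, not a torch API). The pattern they follow, as a standalone sketch:

    import torch

    def find_first_nan(shards):
        # Walk the main-param / grad shards and return the first one containing
        # a NaN, skipping the None placeholders introduced for zero-size groups.
        for t in shards:
            if t is None:
                continue
            if torch.isnan(t).any().item():
                return t
        return None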