OpenDAS / Megatron-LM
"server/vscode:/vscode.git/clone" did not exist on "99874eae7490c4a6aa90fe55f183b58d88397c8f"
Commit 4b9a218e, authored Mar 24, 2022 by Lawrence McAfee
Parent: 41ad0c5d

    losses match.
Showing 1 changed file with 2 additions and 103 deletions.
megatron/optimizer/distrib_optimizer.py  +2 -103
@@ -188,11 +188,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 param_group_map[param] = group_index
 
         # Optimizer group ranges.
-        # >>>
-        # group_ranges = [ {"size": 0, "param_map": {}} for _ in param_groups ]
         group_ranges = [ {"params": []} for _ in param_groups ]
-        # group_ranges = [ [] for _ in param_groups ]
-        # <<<
         for model_gbuf_range_map in model_gbuf_ranges:
             for dtype, gbuf_range_map in model_gbuf_range_map.items():
                 for param in gbuf_range_map["param_map"]:
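For context on the hunk above: the loops fill one "params" list per optimizer group with the model params whose gradient-buffer ranges land on this rank. Below is a minimal sketch of that bookkeeping, assuming model_gbuf_ranges is a list of per-model-chunk maps of dtype to {"param_map": {param: ...}} and param_group_map maps each param to its group index (both shapes inferred from the surrounding code, not confirmed by this diff).

def build_group_ranges(param_groups, param_group_map, model_gbuf_ranges):
    # One {"params": [...]} bucket per optimizer param group.
    group_ranges = [{"params": []} for _ in param_groups]
    for model_gbuf_range_map in model_gbuf_ranges:
        for dtype, gbuf_range_map in model_gbuf_range_map.items():
            for param in gbuf_range_map["param_map"]:
                group_index = param_group_map[param]
                group_ranges[group_index]["params"].append(param)
    # Groups that own no params on this rank are dropped later
    # (see the len(g["params"]) filter in the next hunk).
    return group_ranges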
@@ -205,41 +201,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             group_range["orig_group"] = param_groups[group_index]
         group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]
 
-        # >>>
-        # print_seq("group ranges / len = %s." %
-        #     ", ".join(str(len(s["params"])) for s in group_ranges))
-        # <<<
         return group_ranges
-    # <<<
 
-    # >>>
-    # @classmethod
-    # def allocate_main_param_shards(cls, opt_group_ranges):
-
-    #     # Allocator method.
-    #     allocate_shard = lambda shard_size, dtype : torch.empty(
-    #         (shard_size,),
-    #         dtype = dtype,
-    #         device = torch.cuda.current_device(),
-    #         requires_grad = True)
-
-    #     # Allocate each group's param/grad shard.
-    #     for group_index, group_range in enumerate(opt_group_ranges):
-
-    #         group_size = group_range["size"]
-    #         assert group_size != 0, "temporary check ... remove me."
-
-    #         # Allocate shard.
-    #         main_param = allocate_shard(group_size, torch.float)
-    #         main_param.grad = allocate_shard(group_size, torch.float)
-    #         mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
-
-    #         # Update group's param.
-    #         group_range["orig_group"]["params"] = [ main_param ]
 
     @classmethod
-    # def allocate_main_params(cls, opt_group_ranges):
-    # def allocate_or_view_main_param_shards(cls,
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
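The block deleted above allocated one flat fp32 main-param tensor (plus a matching grad buffer) per optimizer group; the retained build_model_and_main_param_groups builds per-parameter shards instead. A runnable sketch of that deleted allocator is shown below, with the device left generic and the tensor-parallel attribute call omitted (the original used torch.cuda.current_device() and mpu.set_tensor_model_parallel_attributes).

import torch

def allocate_main_param_shard(group_size, device="cpu"):
    # Flat fp32 main param covering the whole group's shard on this rank.
    main_param = torch.empty((group_size,), dtype=torch.float,
                             device=device, requires_grad=True)
    # Matching flat grad buffer of the same size.
    main_param.grad = torch.zeros_like(main_param)
    return main_param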
@@ -255,7 +220,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         shard_fp32_groups = []
         shard_fp32_from_float16_groups = []
 
-        # Allocate each group's param shard.
+        # Allocate (or slice) each group's param shard.
         for group_index, group_range in enumerate(opt_group_ranges):
 
             # Params of this group.
@@ -277,10 +242,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 gbuf_range = model_gbuf_ranges[model_index][dtype]
                 param_range = gbuf_range["param_map"][model_param]["param"]
 
-                # >>>
-                assert param_range.size > 0
-                # <<<
-
                 # fp16, bf16 params.
                 if model_param.type() in ['torch.cuda.HalfTensor',
                                           'torch.cuda.BFloat16Tensor']:
@@ -297,11 +258,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         shard_model_param.shared = model_param.shared
                         shard_main_param.shared = model_param.shared
 
-                    # >>>
-                    assert shard_main_param.nelement() > 0, \
-                        "param_range = %s." % param_range
-                    # <<<
-
                     # Add to group.
                     full_float16_params_this_group.append(model_param)
                     shard_float16_params_this_group.append(shard_model_param)
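The asserts removed above were temporary checks on the shard tensors built just before this context: a view of this rank's slice of the fp16/bf16 model param, plus an fp32 copy used as the main param. A hedged sketch of that pattern follows; param_range.start/.end and a contiguous model_param are assumptions taken from the ranges used elsewhere in this diff.

import torch

def make_param_shards(model_param, param_range):
    # View of this rank's slice of the model param's flattened data.
    shard_model_param = model_param.detach().view(-1)[param_range.start:param_range.end]
    # fp32 copy of that slice; this is what the optimizer actually updates.
    shard_main_param = shard_model_param.clone().float()
    # The removed temporary check: the shard must be non-empty.
    assert shard_main_param.nelement() > 0, "param_range = %s." % param_range
    return shard_model_param, shard_main_param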
@@ -321,9 +277,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                     'torch.cuda.BFloat16Tensor. '
                                     'Received {}'.format(param.type()))
 
-            # # Add to group.
-            # group_main_params.append(main_param)
-
+            # Update optimizer's params.
             group_range["orig_group"]["params"] = [
                 *shard_fp32_params_this_group,
                 *shard_fp32_from_float16_params_this_group,
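Once "params" in the original group is re-pointed at the shard tensors, as in the hunk above, the wrapped inner optimizer only ever steps on this rank's fp32 shards. A small stand-alone illustration with a stand-in shard is given below; the choice of torch.optim.Adam here is an assumption for illustration only, not something this commit specifies.

import torch

# Stand-in shard main param for one optimizer group on this rank.
shard = torch.zeros(16, requires_grad=True)
opt_group = {"params": [shard], "lr": 1e-3}   # plays the role of the re-pointed "orig_group"
inner_optimizer = torch.optim.Adam([opt_group])

shard.grad = torch.randn(16)                  # e.g. a reduce-scattered grad shard
inner_optimizer.step()                        # updates only this rank's shard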
@@ -336,24 +290,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             shard_fp32_groups,
             shard_fp32_from_float16_groups,
         )
-    # <<<
 
-    # >>>
-    # @classmethod
-    # def build_main_grad_views_for_grad_norm(cls, opt_group_ranges, optimizer):
-
-    #     grad_views = []
-    #     for group_index, opt_group_range in enumerate(opt_group_ranges):
-
-    #         opt_grad = optimizer.param_groups[group_index]["params"][0].grad
-    #         for param, range in opt_group_range["param_map"].items():
-    #             if param_is_not_shared(param) and \
-    #                param_is_not_tensor_parallel_duplicate(param):
-    #                 grad_view = opt_grad[range.start:range.end]
-    #                 grad_views.append(grad_view)
-
-    #     return grad_views
-    # <<<
-
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
@@ -702,43 +638,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     # <<<
 
     # >>>
-    # def _copy_main_params_to_model_params(self):
-
-    #     for group_index, group_range in enumerate(self.opt_group_ranges):
-    #         for model_param, main_range in group_range["param_map"].items():
-
-    #             # model_index, dtype = self.param_gbuf_map[model_param]
-    #             # model_range = self.model_gbuf_ranges \
-    #             #     [model_index][dtype]["param_map"][model_param]["gbuf_world"]
-    #             model_range = self.get_model_param_range_map(model_param)["gbuf_world"]
-    #             assert main_range.size == model_range.size
-
-    #             # Use DDP's contiguous buffer to temporarily hold params.
-    #             model_param = self.models[model_index]._grad_buffers[dtype].data
-    #             main_param = self.get_main_param(group_index)
-
-    #             # Copy sub-range within tensor.
-    #             model_view = model_param[model_range.start:model_range.end]
-    #             main_view = main_param[main_range.start:main_range.end]
-    #             model_view.detach().copy_(main_view)
-
-    # def _copy_main_params_to_model_params(self):
-    #     super()._copy_main_params_to_model_params()
-    #     raise Exception("check main param '.grad'.")
-
-    # def _copy_main_params_to_model_params(self):
-    #     raise Exception("hi.")
-
-    #     # This only needs to be done for the float16 group.
-    #     for model_group, main_group in zip(self.float16_groups,
-    #                                        self.fp32_from_float16_groups):
-    #         for model_param, main_param in zip(model_group, main_group):
-    #             model_param.main_grad.detach().copy_(main_param)
-    #     # For fp32 grads, we need to reset the grads to main grad.
-    #     for group in self.fp32_groups:
-    #         for param in group:
-    #             param.main_grad.detach().copy_(param)
-
     def _copy_main_params_to_model_params(self):
 
         def copy_group_params(shard_main_groups, full_model_groups):
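The kept _copy_main_params_to_model_params above delegates to a copy_group_params helper whose body lies outside this hunk. A hedged sketch of what such a helper plausibly does is shown below: each shard main param is written back into the slice of the full model param that this rank owns. The range_of lookup is a hypothetical stand-in for the "param_map"/"gbuf_world" range structures visible elsewhere in this diff, not the repository's actual API.

def copy_group_params(shard_main_groups, full_model_groups, range_of):
    # For each (shard main param, full model param) pair, write the shard back
    # into this rank's slice of the model param's flattened data.
    for shard_main_group, full_model_group in zip(shard_main_groups,
                                                  full_model_groups):
        for shard_main_param, full_model_param in zip(shard_main_group,
                                                      full_model_group):
            r = range_of(full_model_param)  # hypothetical: this rank's slice
            view = full_model_param.detach().view(-1)[r.start:r.end]
            view.copy_(shard_main_param)    # in-place copy back into the model param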