OpenDAS / Megatron-LM

Commit a4f41882, authored Feb 18, 2022 by Lawrence McAfee
Parent: ae6cd9b5

copy model params to main params at very start
Showing 1 changed file with 75 additions and 42 deletions.

megatron/optimizer/optimizer.py  (+75, -42)
...

@@ -696,37 +696,37 @@ class Shard:
 class Float16DistributedOptimizer(BaseFloat16Optimizer):

     # >>>
-    @classmethod
-    def test_reduce_scatter(cls):
-        torch.manual_seed(mpu.get_data_parallel_rank())
-        size = (20,)
-        dtype = torch.float
-        device = torch.cuda.current_device()
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-        data_parallel_group = mpu.get_data_parallel_group()
-        input_list = [
-            # torch.randn(size, dtype = dtype, device = device)
-            5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
-            for _ in range(data_parallel_world_size)
-        ]
-        output = torch.empty(size, dtype = dtype, device = device)
-        torch.distributed.reduce_scatter(
-            output,
-            input_list,
-            group = data_parallel_group,
-        )
-        if torch.distributed.get_rank() == 0:
-            print(output)
-        pax(0, {
-            "data_parallel_world_size" : data_parallel_world_size,
-            "data_parallel_group" : data_parallel_group,
-            "input_list" : input_list,
-            "output" : tp(output),
-        })
+    # @classmethod
+    # def test_reduce_scatter(cls):
+    #     torch.manual_seed(mpu.get_data_parallel_rank())
+    #     size = (20,)
+    #     dtype = torch.float
+    #     device = torch.cuda.current_device()
+    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     input_list = [
+    #         # torch.randn(size, dtype = dtype, device = device)
+    #         5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
+    #         for _ in range(data_parallel_world_size)
+    #     ]
+    #     output = torch.empty(size, dtype = dtype, device = device)
+    #     torch.distributed.reduce_scatter(
+    #         output,
+    #         input_list,
+    #         group = data_parallel_group,
+    #     )
+    #     if torch.distributed.get_rank() == 0:
+    #         print(output)
+    #     pax(0, {
+    #         "data_parallel_world_size" : data_parallel_world_size,
+    #         "data_parallel_group" : data_parallel_group,
+    #         "input_list" : input_list,
+    #         "output" : tp(output),
+    #     })
     # <<<

     @classmethod
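For context on the test that the hunk above comments out: torch.distributed.reduce_scatter sums the i-th input tensor across all ranks and leaves that sum on rank i. Below is a minimal, self-contained sketch of the same call pattern outside Megatron; it assumes an NCCL process group has already been initialized (e.g. via torchrun) and leaves out the mpu helpers and the pax/tp debug utilities used in the original. The function name demo_reduce_scatter is hypothetical.

# Minimal sketch of the reduce_scatter pattern exercised by the commented-out test.
# Assumes torch.distributed is already initialized with the NCCL backend
# (reduce_scatter is not available on gloo) and one GPU per rank.
import torch
import torch.distributed as dist

def demo_reduce_scatter():
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = torch.device("cuda", torch.cuda.current_device())

    size = (20,)
    torch.manual_seed(rank)

    # One input tensor per rank; element-wise sums are computed across ranks,
    # and rank r receives the reduced r-th tensor in `output`.
    input_list = [
        5 * torch.randint(low=1, high=3, size=size, dtype=torch.float, device=device)
        for _ in range(world_size)
    ]
    output = torch.empty(size, dtype=torch.float, device=device)
    dist.reduce_scatter(output, input_list)

    if rank == 0:
        print(output)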
...

@@ -750,10 +750,17 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             if param_local_end > param_local_start:
                 param_local_shard = Shard(param_local_start, param_local_end)
                 param_world_shard = param_local_shard.normalize(param_world_start)
+                sub_param_start = max(0, gbuf_world_shard.start - param_world_start)
+                sub_param_shard = param_local_shard.normalize(sub_param_start)
                 param_shard_map[param] = {
-                    "local" : param_local_shard,
-                    "world" : param_world_shard,
+                    "gbuf_world" : param_world_shard,
+                    "gbuf_local" : param_local_shard,
+                    "param" : sub_param_shard,
                 }
+                # >>>
+                if param_world_start < gbuf_world_shard.start:
+                    raise Exception("hi.")
+                # <<<

         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
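A note on the shard bookkeeping above: judging from the calls in this diff, Shard(start, end) describes a half-open index range with size = end - start, and normalize(start) returns a range of the same size re-based at the given start; the new "param" entry then records where this rank's slice falls within the parameter itself, next to the grad-buffer-local and grad-buffer-world ranges. The toy sketch below works through one example under those assumptions; the Shard re-implementation, the overlap arithmetic, and all numbers are illustrative, not taken from the file.

# Toy sketch (assumption: Shard is a half-open [start, end) range and
# normalize(start) re-bases a range of the same size at `start`).
class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start=0):
        return Shard(start, start + self.size)
    def __repr__(self):
        return "Shard(%d, %d)" % (self.start, self.end)

# Made-up numbers: this rank's slice of the grad buffer covers world
# indices [100, 200), and one parameter occupies world indices [120, 170).
gbuf_world_shard = Shard(100, 200)
param_world_start, param_world_end = 120, 170

# Overlap of the parameter with this rank's slice, relative to the slice start.
param_local_start = max(0, param_world_start - gbuf_world_shard.start)                   # 20
param_local_end = min(gbuf_world_shard.size, param_world_end - gbuf_world_shard.start)   # 70

param_local_shard = Shard(param_local_start, param_local_end)          # [20, 70) within the local slice
param_world_shard = param_local_shard.normalize(param_world_start)     # [120, 170) within the world buffer
sub_param_start = max(0, gbuf_world_shard.start - param_world_start)   # 0
sub_param_shard = param_local_shard.normalize(sub_param_start)         # [0, 50) within the parameter

print(param_local_shard, param_world_shard, sub_param_shard)

In this example the parameter lies entirely inside the rank's slice; a parameter that starts before gbuf_world_shard.start would trip the raise guard added in the hunk above.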
...

@@ -798,17 +805,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     @classmethod
     def get_model_gbuf_shard_map(cls, model):
-        return {
+        # shard_index_map = {
+        shard_map = {
             dtype : cls.get_model_gbuf_shard(model, dtype)
             for dtype in model._grad_buffers
         }
+        # pax(0, {"shard_map": shard_map})
+        return shard_map

     @classmethod
     def get_param_gbuf_map(cls, model_gbuf_shards):
...

@@ -855,7 +856,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             group_index = param_group_map[param]
             group_shard = group_shards[group_index]
-            param_size = gbuf_shard_map["param_map"][param]["local"].size
+            param_size = gbuf_shard_map["param_map"][param]["param"].size
             param_group_start = group_shard["size"]
             param_group_end = param_group_start + param_size
...

@@ -1055,10 +1056,42 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def _collect_main_grad_data_for_unscaling(self):
         return [p.grad.data for p in self.main_param_shards]

+    def _copy_model_params_to_main_params(self):
+
+        for group_index, group_shard in enumerate(self.opt_group_shards):
+            main_param = self.main_param_shards[group_index]
+            for model_param, main_shard in group_shard["param_map"].items():
+
+                # Model shard.
+                model_index, dtype = self.param_gbuf_map[model_param]
+                model_shard = self.model_gbuf_shards \
+                    [model_index][dtype]["param_map"][model_param]["param"]
+                assert main_shard.size == model_shard.size
+
+                # Copy shard data.
+                main_view = main_param[main_shard.start:main_shard.end]
+                model_view = model_param[model_shard.start:model_shard.end].view(-1)
+                main_view.detach().copy_(model_view)
+
+                # pax(0, {
+                #     "main_param" : tp(main_param),
+                #     "model_param" : tp(model_param),
+                #     "main_view" : tp(main_view),
+                #     "model_view" : tp(model_view),
+                #     "main_shard" : str(main_shard),
+                #     "model_shard" : str(model_shard),
+                # })
+
+        pax(0, {
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
+        })
+
     def _copy_model_grads_to_main_grads(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
-            for param, main_shard in group_shard["param_map"].items():
+            for model_param, main_shard in group_shard["param_map"].items():
                 model_index, gbuf_dtype = self.param_gbuf_map[param]
                 model_shard = self.model_gbuf_shards \
...
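The new _copy_model_params_to_main_params above writes each model parameter's owned slice into the matching slice of the fp32 main-param shard with an in-place, autograd-free copy. Below is a self-contained toy of just that copy pattern; the names, shapes, dtypes and index ranges are made up for illustration.

# Toy sketch of the shard-copy pattern: copy a slice of a (here fp16) model
# parameter into the matching slice of an fp32 "main" parameter shard,
# in place and outside autograd.  All ranges and sizes are hypothetical.
import torch

model_param = torch.arange(16, dtype=torch.float16)   # stand-in model weight, flattened
main_param = torch.zeros(8, dtype=torch.float32)      # this rank's fp32 main-param shard

# Hypothetical shards: main-shard slots [2, 7) correspond to model elements [5, 10).
main_start, main_end = 2, 7
model_start, model_end = 5, 10

main_view = main_param[main_start:main_end]
model_view = model_param.view(-1)[model_start:model_end]
main_view.detach().copy_(model_view)   # in-place copy; fp16 values are upcast to fp32

print(main_param)   # slots 2..6 now mirror model elements 5..9

Because main_view is a view into main_param, the copy_ call updates the main shard directly, while detach() keeps the copy out of the autograd graph, mirroring the behavior of the added method.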