OpenDAS / Megatron-LM / Commit 1215c420

Commit 1215c420, authored Feb 14, 2022 by Lawrence McAfee
Parent: c5f93269

    tweaked slice index naming convention

Showing 2 changed files with 78 additions and 27 deletions:

    megatron/model/distributed.py     +8   -4
    megatron/optimizer/optimizer.py   +70  -23
megatron/model/distributed.py

@@ -164,10 +164,14 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     # type_num_elements[dtype]
                     if dtype not in self._grad_buffer_param_index_map:
                         self._grad_buffer_param_index_map[dtype] = {}
-                    self._grad_buffer_param_index_map[dtype][param] = {
-                        "start" : type_num_elements[dtype],
-                        "end" : param.data.nelement(),
-                    }
+                    # self._grad_buffer_param_index_map[dtype][param] = {
+                    #     "start" : type_num_elements[dtype],
+                    #     "end" : type_num_elements[dtype] + param.data.nelement(),
+                    # }
+                    self._grad_buffer_param_index_map[dtype][param] = (
+                        type_num_elements[dtype],
+                        type_num_elements[dtype] + param.data.nelement(),
+                    )
                     # <<<

         # Backward hook.
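For readers following the diff: the per-parameter entry in _grad_buffer_param_index_map is now a plain (start, end) tuple of element offsets into the flat grad buffer, replacing the earlier {"start": ..., "end": ...} dict. A minimal sketch of how such a tuple-valued index map can be built and used to view one parameter's slice of a flat buffer is below; the buffer, parameter names, and forward-growing offsets are illustrative assumptions, not the commit's code.

    import torch

    # Two hypothetical parameters sharing one flat gradient buffer.
    params = [torch.nn.Parameter(torch.zeros(3, 4)), torch.nn.Parameter(torch.zeros(5))]
    grad_buffer = torch.zeros(sum(p.data.nelement() for p in params))

    # Build a param -> (start, end) map, mirroring the tuple convention above.
    index_map = {}
    offset = 0
    for p in params:
        index_map[p] = (offset, offset + p.data.nelement())
        offset += p.data.nelement()

    # A parameter's gradient is then a view of buffer[start:end], reshaped.
    start, end = index_map[params[0]]
    grad_view = grad_buffer[start:end].view(params[0].shape)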
megatron/optimizer/optimizer.py

@@ -802,27 +802,48 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             local_shard_end_index = local_shard_info["end"]
             local_shard_size = local_shard_info["size"]

-            # Local shard's param index map.
-            local_shard_info["param_index_map"] = {}
+            # Local shard's param 'slice' index map.
+            local_shard_info["param_slice_index_map"] = {}
             for param, offset_dict in model_param_group["offset_map"].items():

-                param_start_index = offset_dict["start"]
-                param_end_index = offset_dict["end"]
-                param_shard_start_index = max(local_shard_start_index,
-                                              param_start_index)
-                param_shard_end_index = min(local_shard_end_index,
-                                            param_end_index)
+                # param_start_index = offset_dict["start"]
+                # param_end_index = offset_dict["end"]
+                # param_shard_start_index = max(local_shard_start_index,
+                #                               param_start_index)
+                # param_shard_end_index = min(local_shard_end_index,
+                #                             param_end_index)
+                orig_start_index = offset_dict["start"]
+                orig_end_index = offset_dict["end"]
+                shard_start_index = max(0, orig_start_index - local_shard_start_index)
+                shard_end_index = min(local_shard_end_index,
+                                      orig_end_index - local_shard_start_index)

                 if param_shard_end_index > param_shard_start_index:

                     # Indexes are relative to local shard start index.
-                    local_shard_info["param_index_map"][param] = {
-                        "param" : (
-                            param_shard_start_index,
-                            param_shard_end_index,
-                        ),
-                        "shard" : (
-                            param_shard_start_index - local_shard_start_index,
-                            param_shard_end_index - local_shard_start_index,
-                        ),
-                    }
+                    # local_shard_info["param_index_map"][param] = {
+                    #     "param" : (
+                    #         param_shard_start_index,
+                    #         param_shard_end_index,
+                    #     ),
+                    #     "shard" : (
+                    #         param_shard_start_index - local_shard_start_index,
+                    #         param_shard_end_index - local_shard_start_index,
+                    #     ),
+                    # }
+                    # local_shard_info["param_slice_index_map"][param] = {
+                    #     "param_start" :
+                    #         param_shard_start_index,
+                    #     "shard_start" :
+                    #         param_shard_start_index - local_shard_start_index,
+                    #     "size":
+                    #         param_shard_end_index - param_shard_start_index,
+                    # }
+                    local_shard_info["param_slice_index_map"][param] = {
+                        "orig_start" : orig_start_index,
+                        "shard_start" : shard_start_index,
+                        "size" : shard_end_index - shard_start_index,
+                    }

                     # pax(0, {
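An aside on the interval arithmetic this hunk revolves around: each parameter occupies a global element range [start, end) within its parameter group, and the optimizer needs the portion of that range falling inside the local shard, expressed both globally and relative to the shard. A small self-contained sketch of that clipping, written to match the new 'orig_start' / 'shard_start' / 'size' convention, follows; the function name and example numbers are illustrative only, not from the commit.

    def clip_param_to_shard(param_start, param_end, shard_start, shard_end):
        # Overlap of the parameter's global [param_start, param_end) range
        # with the shard's global [shard_start, shard_end) range.
        overlap_start = max(param_start, shard_start)
        overlap_end = min(param_end, shard_end)
        if overlap_end <= overlap_start:
            return None  # parameter does not touch this shard
        return {
            "orig_start"  : overlap_start,                # global offset of the slice
            "shard_start" : overlap_start - shard_start,  # offset within the shard
            "size"        : overlap_end - overlap_start,  # number of elements
        }

    # A param spanning global elements [100, 250) against a shard spanning [200, 400)
    # yields {"orig_start": 200, "shard_start": 0, "size": 50}.
    print(clip_param_to_shard(100, 250, 200, 400))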
@@ -854,7 +875,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             main_param_shards = {
                 ty : allocate_shard(local_shard_size, ty)
                 for ty in model_main_dtypes}
-            self.main_param_shard_groups.append(main_param_shards)
+            # self.main_param_shard_groups.append(main_param_shards)
+            local_shard_info["data"] = main_param_shards

             # Update optimizer group.
             self.optimizer.param_groups[group_index]["params"] = \
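The change above stops appending the per-dtype main-param shards to self.main_param_shard_groups and instead hangs them directly off local_shard_info["data"]. A rough sketch of the pattern, one flat shard per dtype built with a dict comprehension, is shown below; the allocate_shard here is a stand-in written for illustration and is not the commit's implementation.

    import torch

    def allocate_shard(num_elements, dtype):
        # Illustrative stand-in: a flat, zeroed tensor sized to this rank's shard.
        return torch.zeros(num_elements, dtype=dtype)

    model_main_dtypes = [torch.float16, torch.float32]
    local_shard_info = {"size": 1024}
    local_shard_size = local_shard_info["size"]

    # One main-param shard per dtype, attached directly to the shard info.
    main_param_shards = {
        ty : allocate_shard(local_shard_size, ty)
        for ty in model_main_dtypes}
    local_shard_info["data"] = main_param_shards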
@@ -935,16 +957,41 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         for group_index, local_shard_info in enumerate(local_shard_info_groups):

             # model_param_index_map =
-            shard_param_index_map = local_shard_info["param_index_map"]
-            for param, shard_indexes in shard_param_index_map.items():
+            # shard_param_index_map = local_shard_info["param_index_map"]
+            # main_index_map = local_shard_info["param_index_map"]
+            main_slice_index_map = local_shard_info["param_slice_index_map"]
+            for param, main_slice_indexes in main_slice_index_map.items():
+
+                main_param_start_index = main_slice_indexes["param_start"]
+                main_shard_start_index = main_slice_indexes["shard_start"]
+                main_slice_size = ddd
+                main_size = main_shard_indexesddd

                 dtype_model_dict = param_model_map[param]
                 dtype = dtype_model_dict["dtype"]
                 vmodel = dtype_model_dict["model"]

-                grad_buffer_indexes = \
-                    vmodel._grad_buffer_param_index_map[dtype][param]
+                model_grad_buffer = vmodel._grad_buffers[dtype]
+                model_grad_buffer_start_index = \
+                    vmodel._grad_buffer_param_index_map[dtype][param][0]
+                # model_grad_buffer_indexes = [ model_grad_buffer_start_index + i
+                #                               for i in main_
+                # model_grad_view = model_grad_buffer.data[
+                pax(0, {"model_grad_buffer_indexes": model_grad_buffer_indexes})
+
+                main_grad_view = self.main_param_shard_groups \
+                    [group_index][torch.float].grad \
+                    [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
+
+                pax(0, {"dtype": dtype})
+                pax(0, {
+                    # "dtype" : dtype,
+                    # "vmodel" : vmodel,
+                    "shard_indexes" : shard_indexes,
+                    "grad_buffer_indexes" : grad_buffer_indexes,
+                    "model_grad_view" : model_grad_view,
+                    "main_grad_views" : main_grad_view,
+                })

                 pax(0, {
                     "group_index" : group_index,
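This last loop is clearly still work in progress (the ddd placeholders and the pax(...) debug dumps would stop execution), but the direction is visible: take the parameter's slice of the model's flat grad buffer and line it up against the matching region of the fp32 main-grad shard. A minimal, self-contained sketch of that kind of copy under the new 'orig_start' / 'shard_start' / 'size' convention is below; every name and size here is illustrative, and none of it is the commit's code.

    import torch

    # Hypothetical flat buffers: a model-side fp16 grad buffer and this rank's
    # fp32 main-grad shard.
    model_grad_buffer = torch.randn(1024, dtype=torch.float16)
    main_grad_shard = torch.zeros(256, dtype=torch.float32)

    # One hypothetical param_slice_index_map entry: where the param's slice starts
    # in the model buffer, where it starts in the local shard, and its length.
    slice_info = {"orig_start": 512, "shard_start": 0, "size": 256}

    # View the model-grad slice and the matching main-grad region, then copy;
    # copy_() casts fp16 -> fp32 on the way in.
    src = model_grad_buffer[slice_info["orig_start"]:
                            slice_info["orig_start"] + slice_info["size"]]
    dst = main_grad_shard[slice_info["shard_start"]:
                          slice_info["shard_start"] + slice_info["size"]]
    dst.copy_(src)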