OpenDAS / Megatron-LM / Commits

Commit a3f3c3ad, authored Feb 14, 2022 by Lawrence McAfee

    todo; align shards with model's contiguous buffer

parent 3f0bc681
Changes: 2 changed files, with 139 additions and 104 deletions.

- megatron/model/distributed.py (+1, -1)
- megatron/optimizer/optimizer.py (+138, -103)
megatron/model/distributed.py

...
@@ -122,7 +122,7 @@ class DistributedDataParallel(DistributedDataParallelBase):

        # ===================================
        self._grad_buffers = None
        # >>>
        # from collections import defaultdict
        # self._grad_buffer_param_offsets = None
        self._grad_buffer_param_index_map = None
        # <<<
...
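The `_grad_buffer_param_index_map` attribute is later indexed as `_grad_buffer_param_index_map[dtype][param][0]` in the optimizer hunk below, i.e. it maps each grad-buffer dtype to a param -> (start, end) index map into that dtype's contiguous grad buffer. As a minimal, self-contained sketch of how such a map can be built (not the Megatron implementation; `params` here is a hypothetical list):

    import torch

    def build_param_index_map(params):
        # Assign each param a contiguous (start, end) slice of a flat buffer.
        index_map = {}
        offset = 0
        for param in params:
            index_map[param] = (offset, offset + param.numel())
            offset += param.numel()
        return index_map, offset  # offset == total buffer numel

    params = [torch.nn.Parameter(torch.zeros(4, 3)),
              torch.nn.Parameter(torch.zeros(7))]
    index_map, numel = build_param_index_map(params)
    print(sorted(index_map.values()), numel)  # [(0, 12), (12, 19)] 19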
megatron/optimizer/optimizer.py

...
@@ -770,35 +770,35 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # (Also, collect world DP shard info.)
        # model_main_dtypes = set([ args.params_dtype, torch.float ])
        model_main_dtypes = set([ torch.float ]) # fp32 only, for now

        # self.world_shard_info_groups = [] # world_group_shard_infos ?
        # self.main_param_shard_groups = []
        self.world_shard_infos = [ {"groups" : []} for _ in self.model_param_groups ]
        for group_index, model_param_group in enumerate(self.model_param_groups):

            # pax(0, {
            #     "model_param_group" : model_param_group,
            #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
            # })

            # Max world shard size.
            model_param_size = model_param_group["size"]
            max_world_shard_size = int(math.ceil(model_param_size / self.data_parallel_world_size))

            # DP world shard infos.
            # world_shard_infos = []
            for r in range(self.data_parallel_world_size):
                shard_start_index = r * max_world_shard_size
                shard_end_index = min(model_param_size, shard_start_index + max_world_shard_size)
                # world_shard_infos.append({
                self.world_shard_infos[r]["groups"].append({
                    "start" : shard_start_index,
                    "end" : shard_end_index,
                    "size" : shard_end_index - shard_start_index,
                })
            # self.world_shard_info_groups.append(world_shard_infos)
            # self.world_shard_infos[group_index].append(world_shard_infos)

            # DP local rank's shard info.
            # local_shard_info = world_shard_infos[self.data_parallel_rank]
            local_shard_info = \
                self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
            local_shard_start_index = local_shard_info["start"]
            local_shard_end_index = local_shard_info["end"]
            local_shard_size = local_shard_info["size"]
...
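For intuition, the shard math in the hunk above is plain ceil division: every data-parallel rank gets an equally sized slice of at most max_world_shard_size elements, and the final shard is clipped to the parameter-group size. A standalone sketch with hypothetical sizes:

    import math

    def shard_ranges(model_param_size, data_parallel_world_size):
        max_world_shard_size = int(math.ceil(model_param_size / data_parallel_world_size))
        return [(r * max_world_shard_size,
                 min(model_param_size, (r + 1) * max_world_shard_size))
                for r in range(data_parallel_world_size)]

    # e.g. 10 elements over 4 ranks:
    print(shard_ranges(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]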
@@ -895,12 +895,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "params" : self.optimizer.param_groups[group_index]["params"],
# })
# Add world start/end indexes, for reduce/gather steps.
offset
=
0
for
r
in
self
.
world_shard_infos
:
r
[
"start_index"
]
=
offset
offset
+=
sum
(
g
[
"size"
]
for
g
in
r
[
"groups"
])
r
[
"end_index"
]
=
offset
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self
.
optimizer
.
load_state_dict
(
self
.
optimizer
.
state_dict
())
# >>>
# pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# pax(0, {
# "world_shard_infos" : self.world_shard_infos,
# **{
# "world_shard_infos / %d" % i : r
# for i, r in enumerate(self.world_shard_infos)
# },
# })
# <<<
# def get_loss_scale(self):
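The start_index/end_index pass above is a running prefix sum over each rank's per-group shard sizes, so every rank ends up owning one contiguous world-level span [start_index, end_index). A toy sketch with made-up sizes:

    world_shard_infos = [
        {"groups" : [{"size" : 3}, {"size" : 5}]},  # rank 0
        {"groups" : [{"size" : 3}, {"size" : 4}]},  # rank 1
    ]
    offset = 0
    for r in world_shard_infos:
        r["start_index"] = offset
        offset += sum(g["size"] for g in r["groups"])
        r["end_index"] = offset
    print([(r["start_index"], r["end_index"]) for r in world_shard_infos])
    # [(0, 8), (8, 15)] -> rank 0 covers [0, 8), rank 1 covers [8, 15)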
...
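The load_state_dict(state_dict()) round-trip above works because torch.optim.Optimizer.load_state_dict casts floating-point per-param state tensors to the dtype and device of the corresponding params, so re-pointing the params (e.g. at new fp32 main shards) and round-tripping recasts the optimizer state to match. A standalone sketch of that behavior:

    import torch

    param = torch.nn.Parameter(torch.zeros(4))
    opt = torch.optim.Adam([param], lr=1e-3)
    param.grad = torch.ones(4)
    opt.step()  # creates exp_avg / exp_avg_sq as float32
    print(opt.state[param]["exp_avg"].dtype)  # torch.float32

    # Recast the param, then round-trip the state; load_state_dict casts
    # floating-point state tensors to the (new) param dtype.
    param.data = param.data.half()
    opt.load_state_dict(opt.state_dict())
    print(opt.state[param]["exp_avg"].dtype)  # torch.float16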
@@ -931,107 +944,129 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# "params" : params,
# })
def
reduce_gradients
(
self
,
model
):
# def reduce_gradients(self, model):
# # >>>
# # pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# # <<<
# # >>>
# args = get_args()
# # timers = get_timers()
# # <<<
# # >>> [ temporary requirement ... and already checked in arguments.py ]
# assert args.use_contiguous_buffers_in_local_ddp
# # <<<
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Map param to virtual model.
# # ** ideally, this should happen once, during construction.
# param_model_map = {}
# for vmodel in model:
# for dtype, param_index_map in \
# vmodel._grad_buffer_param_index_map.items():
# for param in param_index_map:
# param_model_map[param] = {
# "dtype" : dtype,
# "model" : vmodel,
# }
# # pax(0, {
# # "param_model_map" : [
# # (str(tuple(p.shape)), m)
# # for p, m in param_model_map.items()
# # ],
# # })
# # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# # Copy model grads to main shard.
# local_shard_info_groups = [g[self.data_parallel_rank]
# for g in self.world_shard_info_groups]
# for group_index, local_shard_info in enumerate(local_shard_info_groups):
# # model_param_index_map =
# # shard_param_index_map = local_shard_info["param_index_map"]
# # main_index_map = local_shard_info["param_index_map"]
# main_slice_index_map = local_shard_info["param_slice_index_map"]
# for param, main_slice_indexes in main_slice_index_map.items():
# main_slice_orig_start_index = main_slice_indexes["orig_start"]
# main_slice_shard_start_index = main_slice_indexes["shard_start"]
# main_slice_size = main_slice_indexes["size"]
# dtype_model_dict = param_model_map[param]
# dtype = dtype_model_dict["dtype"]
# vmodel = dtype_model_dict["model"]
# model_grad_buffer = vmodel._grad_buffers[dtype].data
# model_grad_buffer_start_index = \
# vmodel._grad_buffer_param_index_map[dtype][param][0] + \
# main_slice_orig_start_index
# main_grad_view = local_shard_info["data"][torch.float].grad[
# main_slice_shard_start_index:
# main_slice_shard_start_index + main_slice_size
# ]
# model_grad_view = model_grad_buffer[
# model_grad_buffer_start_index:
# model_grad_buffer_start_index + main_slice_size
# ]
# main_grad_view.detach().copy_(model_grad_view)
# # pax(0, {
# # # "local_shard_info" : local_shard_info,
# # "main_slice_orig_start_index" : main_slice_orig_start_index,
# # "main_slice_shard_start_index" : main_slice_shard_start_index,
# # "main_slice_size" : main_slice_size,
# # "model_grad_buffer_start_index" :
# # model_grad_buffer_start_index,
# # "main_grad_view" : tp(main_grad_view),
# # "main_grad_view / detach" : tp(main_grad_view.detach()),
# # "model_grad_view" : tp(model_grad_view),
# # })
# # pax(0, {
# # "group_index" : group_index,
# # "local_shard_info" : local_shard_info,
# # "shard_param_index_map" : shard_param_index_map,
# # "param" : tp(param),
# # "shard_indexes" : shard_indexes,
# # "grad_buffer_indexes" : grad_buffer_indexes,
# # })
# pax(0, {
# # "world_shard_info_groups" : self.world_shard_info_groups,
# # **{"world_shard_info_groups / %d" % i : v
# # for i, v in enumerate(self.world_shard_info_groups)},
# # "local_shard_info_groups" : local_shard_info_groups,
# "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
# })
# >>>
# pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# <<<
def
reduce_gradients
(
self
,
model
):
# >>>
args
=
get_args
()
# timers = get_timers()
# <<<
# >>> [ temporary requirement ... and already checked in arguments.py ]
assert
args
.
use_contiguous_buffers_in_local_ddp
# <<<
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Map param to virtual model.
# ** ideally, this should happen once, during construction.
param_model_map
=
{}
for
vmodel
in
model
:
for
dtype
,
param_index_map
in
\
vmodel
.
_grad_buffer_param_index_map
.
items
():
for
param
in
param_index_map
:
param_model_map
[
param
]
=
{
"dtype"
:
dtype
,
"model"
:
vmodel
,
}
# pax(0, {
# "param_model_map" : [
# (str(tuple(p.shape)), m)
# for p, m in param_model_map.items()
# ],
# })
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Copy model grads to main shard.
local_shard_info_groups
=
[
g
[
self
.
data_parallel_rank
]
for
g
in
self
.
world_shard_info_groups
]
for
group_index
,
local_shard_info
in
enumerate
(
local_shard_info_groups
):
# model_param_index_map =
# shard_param_index_map = local_shard_info["param_index_map"]
# main_index_map = local_shard_info["param_index_map"]
main_slice_index_map
=
local_shard_info
[
"param_slice_index_map"
]
for
param
,
main_slice_indexes
in
main_slice_index_map
.
items
():
main_slice_orig_start_index
=
main_slice_indexes
[
"orig_start"
]
main_slice_shard_start_index
=
main_slice_indexes
[
"shard_start"
]
main_slice_size
=
main_slice_indexes
[
"size"
]
dtype_model_dict
=
param_model_map
[
param
]
dtype
=
dtype_model_dict
[
"dtype"
]
vmodel
=
dtype_model_dict
[
"model"
]
model_grad_buffer
=
vmodel
.
_grad_buffers
[
dtype
].
data
model_grad_buffer_start_index
=
\
vmodel
.
_grad_buffer_param_index_map
[
dtype
][
param
][
0
]
+
\
main_slice_orig_start_index
main_grad_view
=
local_shard_info
[
"data"
][
torch
.
float
].
grad
[
main_slice_shard_start_index
:
main_slice_shard_start_index
+
main_slice_size
]
model_grad_view
=
model_grad_buffer
[
model_grad_buffer_start_index
:
model_grad_buffer_start_index
+
main_slice_size
]
main_grad_view
.
detach
().
copy_
(
model_grad_view
)
# pax(0, {
# # "local_shard_info" : local_shard_info,
# "main_slice_orig_start_index" : main_slice_orig_start_index,
# "main_slice_shard_start_index" : main_slice_shard_start_index,
# "main_slice_size" : main_slice_size,
# "model_grad_buffer_start_index" :
# model_grad_buffer_start_index,
# "main_grad_view" : tp(main_grad_view),
# "main_grad_view / detach" : tp(main_grad_view.detach()),
# "model_grad_view" : tp(model_grad_view),
# })
# pax(0, {
# "group_index" : group_index,
# "local_shard_info" : local_shard_info,
# "shard_param_index_map" : shard_param_index_map,
# "param" : tp(param),
# "shard_indexes" : shard_indexes,
# "grad_buffer_indexes" : grad_buffer_indexes,
# })
pax
(
0
,
{
# "world_shard_info_groups" : self.world_shard_info_groups,
# **{"world_shard_info_groups / %d" % i : v
# for i, v in enumerate(self.world_shard_info_groups)},
# "local_shard_info_groups" : local_shard_info_groups,
"local_shard_info_groups"
:
[
g
[
"data"
]
for
g
in
local_shard_info_groups
],
})
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
assert
args
.
use_contiguous_buffers_in_local_ddp
world_sizes
=
[]
for
r
in
self
.
world_shard_infos
:
# world_sizes.append(sum(g["size"] for g in r))
world_sizes
.
append
([
g
[
"size"
]
for
g
in
r
[
"groups"
]
])
# grad_refs ...
pax
(
0
,
{
"world_sizes"
:
world_sizes
})
# for world_grads = []
# for world_shard_info_group
# x ?
raise
Exception
(
"hi."
)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
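The core of the copy-grads loop above is a view-to-view copy: a slice of the model's contiguous grad buffer is copied (and upcast) into the matching slice of this rank's fp32 main-grad shard. A self-contained sketch of just that step, with made-up tensors and indexes:

    import torch

    model_grad_buffer = torch.arange(16, dtype=torch.half)  # contiguous fp16 grads
    main_shard_grad = torch.zeros(8, dtype=torch.float)     # this rank's fp32 shard

    # Hypothetical slice: starts at 6 in the grad buffer, at 2 in the shard, size 4.
    buffer_start, shard_start, size = 6, 2, 4
    model_grad_view = model_grad_buffer[buffer_start:buffer_start + size]
    main_grad_view = main_shard_grad[shard_start:shard_start + size]
    main_grad_view.copy_(model_grad_view)  # in-place copy casts fp16 -> fp32

    print(main_shard_grad)  # elements 2..5 now hold grads 6., 7., 8., 9.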
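The unfinished "Reduce-scatter" section stops at raise Exception("hi."), but its comments point at the intended step: with equally sized shards over a contiguous grad buffer, each data-parallel rank can receive only its own shard, already summed across ranks. A hedged sketch of that step using torch.distributed (not this commit's code; assumes an initialized process group and a buffer numel divisible by the world size):

    import torch
    import torch.distributed as dist

    def reduce_scatter_grad_buffer(grad_buffer, group=None):
        # Split the contiguous buffer into one equal shard per rank ...
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        shard_size = grad_buffer.numel() // world_size
        shards = list(grad_buffer.view(world_size, shard_size).unbind(0))
        # ... then sum across ranks, each rank keeping only its own shard.
        output = torch.empty_like(shards[rank])
        dist.reduce_scatter(output, shards, group=group)
        return output  # this rank's reduced shard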