OpenDAS / Megatron-LM · Commits · 55695f81

Commit 55695f81, authored Mar 09, 2022 by Lawrence McAfee
Parent: 4f2356dc

    padded ddp's grad_buffer to multiple of data parallel world size

Showing 2 changed files with 113 additions and 78 deletions:
    megatron/model/distributed.py     +34  -4
    megatron/optimizer/optimizer.py   +79  -74

megatron/model/distributed.py  (view file @ 55695f81)
...
...
@@ -15,6 +15,9 @@
from abc import ABC
from abc import abstractmethod
# >>>
import math
# <<<
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
...
...
@@ -27,14 +30,16 @@ from .module import MegatronModule
class MemoryBuffer:

    def __init__(self, numel, dtype):
    # >>>
    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
        self.dtype = dtype
        self.data = torch.zeros(self.numel,
        self.data = torch.zeros(self.numel_padded,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)
    # <<<

    def zero(self):
        """Reset the buffer to zero."""
...
...
@@ -132,6 +137,7 @@ class DistributedDataParallel(DistributedDataParallelBase):
        # self._grad_buffer_param_offsets = defaultdict(dict)
        # self._grad_buffer_param_index_map = defaultdict(dict)
        self._grad_buffer_param_index_map = {}
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        # <<<

        # Simple function to define buffer type.
...
...
@@ -149,7 +155,31 @@ class DistributedDataParallel(DistributedDataParallelBase):
        # Allocate the buffer.
        for dtype, num_elements in type_num_elements.items():
            self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # >>>
            # If using distributed optimizer, pad memory buffer to be
            # multiple of data_parallel_world_size. (This padding is done
            # due to a constraint with the reduce_scatter op, which requires
            # all tensors have equal size. See: optimizer.py.)
            num_elements_padded = data_parallel_world_size * \
                int(math.ceil(num_elements / data_parallel_world_size))
            # <<<

            # Allocate grad buffer.
            self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                     num_elements_padded,
                                                     dtype)
            # >>>
            # from lutil import pax
            # if True or num_elements % data_parallel_world_size != 0:
            #     pax(0, {
            #         "data_parallel_world_size" : data_parallel_world_size,
            #         "num_elements" : num_elements,
            #         "num_elements_padded" : num_elements_padded,
            #         "modulo" : num_elements % data_parallel_world_size,
            #         "grad buffer" : self._grad_buffers[dtype],
            #     })
            # <<<

        # Assume the back prop order is reverse the params order,
        # store the start index for the gradients.
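Aside (not part of the commit): the padding arithmetic above in miniature, with hypothetical sizes. reduce_scatter requires every rank's input view to have the same numel, so the buffer is rounded up to the next multiple of the data-parallel world size.

import math

data_parallel_world_size = 8
num_elements = 1_000_003     # toy grad-buffer size, not divisible by 8

num_elements_padded = data_parallel_world_size * \
    int(math.ceil(num_elements / data_parallel_world_size))

assert num_elements_padded % data_parallel_world_size == 0
shard_size = num_elements_padded // data_parallel_world_size
print(num_elements_padded, shard_size)   # 1000008 125001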
...
...
megatron/optimizer/optimizer.py  (view file @ 55695f81)
...
...
@@ -626,17 +626,11 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # >>>
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                # +++
                # grad_shard = optimizer.get_grad_shard(word_embeddings)
                # torch.distributed.all_reduce(grad_shard,
                #                              group=mpu.get_embedding_group())
                # <<<

        # All-reduce position_embeddings grad across first (encoder) and split (decoder)
        # stages to ensure that position embeddings parameters stay in sync.
...
...
@@ -652,15 +646,8 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            # >>>
            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
            # +++
            # grad_shard = optimizer.get_grad_shard(
            #     unwrapped_model.language_model.embedding.position_embeddings.weight)
            # torch.distributed.all_reduce(grad_shard,
            #                              group=mpu.get_position_embedding_group())
            # <<<

        timers('backward-embedding-all-reduce').stop()

    def gather_params(self, ITERATION):
...
...
@@ -717,19 +704,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # pax(1, {"main_grads": main_grads})

        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        # >>>
        # from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
        # pax(1, {"main_grads": [ (param_is_not_tensor_parallel_duplicate(t), tp(t)) for t in main_grads ]})
        # <<<

        return main_grads
...
...
@@ -827,40 +807,6 @@ class Shard:
# class Float16DistributedOptimizer(MegatronOptimizer):
class Float16DistributedOptimizer(BaseFloat16Optimizer):

    # >>>
    # @classmethod
    # def test_reduce_scatter(cls):
    #     torch.manual_seed(mpu.get_data_parallel_rank())
    #     size = (20,)
    #     dtype = torch.float
    #     device = torch.cuda.current_device()
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     data_parallel_group = mpu.get_data_parallel_group()
    #     input_list = [
    #         # torch.randn(size, dtype = dtype, device = device)
    #         5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
    #         for _ in range(data_parallel_world_size)
    #     ]
    #     output = torch.empty(size, dtype = dtype, device = device)
    #     torch.distributed.reduce_scatter(
    #         output,
    #         input_list,
    #         group = data_parallel_group,
    #     )
    #     if torch.distributed.get_rank() == 0:
    #         print(output)
    #     pax(0, {
    #         "data_parallel_world_size" : data_parallel_world_size,
    #         "data_parallel_group" : data_parallel_group,
    #         "input_list" : input_list,
    #         "output" : tp(output),
    #     })
    # <<<

    @classmethod
    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
...
...
@@ -913,6 +859,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
            gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_shards.append(gbuf_world_shard)
            # >>>
            # if max_gbuf_shard_size != gbuf_world_shard.size:
            #     raise Exception("%d: smaller, rank %d. [ %d -> %d vs. %d]" % (
            #         data_parallel_rank,
            #         r,
            #         gbuf_size,
            #         max_gbuf_shard_size,
            #         gbuf_world_shard.size,
            #     ))
            # <<<

        gbuf_world_shard = gbuf_world_all_shards[data_parallel_rank]
        gbuf_local_shard = gbuf_world_shard.normalize()
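Aside (not part of the commit): a standalone sketch of the per-rank world-shard computation above, with toy sizes and a namedtuple standing in for the Shard class defined in optimizer.py.

import math
from collections import namedtuple

Shard = namedtuple('Shard', ['start', 'end'])   # stand-in for optimizer.py's Shard

gbuf_size = 12                    # padded grad-buffer size (toy value)
data_parallel_world_size = 4
max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))

gbuf_world_all_shards = []
for r in range(data_parallel_world_size):
    gbuf_world_start = r * max_gbuf_shard_size
    gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)
    gbuf_world_all_shards.append(Shard(gbuf_world_start, gbuf_world_end))

print(gbuf_world_all_shards)
# [Shard(start=0, end=3), Shard(start=3, end=6), Shard(start=6, end=9), Shard(start=9, end=12)]
# Because gbuf_size is a multiple of the world size, every rank's shard has equal length.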
...
...
@@ -927,9 +883,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            "world" : gbuf_world_shard,
            "world_all" : gbuf_world_all_shards,
            "param_map" : param_shard_map,
            "max_shard_size" : max_gbuf_shard_size,
        }
        # pax(1, {"data": data})
        # pax(0, {"data": data})

        return data
...
...
@@ -992,9 +949,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            param_group_end = param_group_start + param_size
            param_group_shard = Shard(param_group_start, param_group_end)
            # group_shard["max_size"] = gbuf_shard_map["max_shard_size"]
            group_shard["size"] += param_size
            group_shard["param_map"][param] = param_group_shard
            # pax(0, {"gbuf_shard_map": gbuf_shard_map})
            # >>>
            # if torch.distributed.get_rank() == 1:
            #     print(">>> [%d] ... group %d, size %d, param %s. <<<" % (
...
...
@@ -1010,6 +969,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            group_shard["orig_group"] = param_groups[group_index]
        group_shards = [ g for g in group_shards if g["size"] > 0 ]

        # [ ... x ... ] Synchronize group sizes across ranks.
        # pax(0, {
        #     "param_group_map": [
        #         (g, str(p.shape))
...
...
@@ -1035,6 +996,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # main_param_shards = []
        for group_index, group_shard in enumerate(opt_group_shards):
            # pax(0, {
            #     "group_shard" : group_shard,
            # })
            group_size = group_shard["size"]
            assert group_size != 0, "temporary check ... remove me."
...
...
@@ -1075,29 +1040,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        assert args.use_contiguous_buffers_in_local_ddp  # already checked in args
        # <<<

        # # Data parallel info.
        # self.data_parallel_group = mpu.get_data_parallel_group()
        # self.data_parallel_rank = mpu.get_data_parallel_rank()
        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)
        # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})

        # Allocate main param shards.
        # self.main_param_shards = \
        #     self.allocate_main_param_shards(self.opt_group_shards)
        self.allocate_main_param_shards(self.opt_group_shards)

        # >>>
...
...
@@ -1205,6 +1159,37 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # pax(0, {"model_params": model_params})

    # def get_model_grad_buffer_dp_views(self):
    #     # >>>
    #     # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
    #     args = get_args()
    #     assert args.use_contiguous_buffers_in_local_ddp
    #     # <<<
    #     # Grad buffer views.
    #     gbuf_view_items = []
    #     for model_index, model in enumerate(self.models):
    #         for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
    #             world_shards = gbuf_shard["world_all"]
    #             gbuf = model._grad_buffers[dtype].data
    #             gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
    #             gbuf_view_items.append((model_index, dtype, gbuf_views))
    #             # pax(0, {
    #             #     "world_shards" : world_shards,
    #             #     "gbuf_views" : gbuf_views,
    #             # })
    #     pax(0, {
    #         "gbuf_view_items" : gbuf_view_items,
    #         **{
    #             "views / %d" % i : item[2]
    #             for i, item in enumerate(gbuf_view_items)
    #         },
    #     })
    #     return gbuf_view_items

    def get_model_grad_buffer_dp_views(self):
        # >>>
...
...
@@ -1213,21 +1198,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        assert args.use_contiguous_buffers_in_local_ddp
        # <<<

        # data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                world_shards = gbuf_shard["world_all"]
                gbuf = model._grad_buffers[dtype].data
                gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
                gbuf_view_items.append((model_index, dtype, gbuf_views))
            for dtype, gbuf in model._grad_buffers.items():
                # gbuf_size = gbuf.numel_padded
                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                # pax(0, {
                #     "world_shards" : world_shards,
                #     "gbuf_views" : gbuf_views,
                #     "numel" : gbuf.numel,
                #     "numel_padded" : gbuf.numel_padded,
                #     "shard_size / f" : gbuf.numel_padded/data_parallel_world_size,
                #     "shard_size / i" : shard_size,
                # })
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf_views))

        # pax(0, {"gbuf_view_items": gbuf_view_items})
        # pax(0, {
        #     "gbuf_view_items" : gbuf_view_items,
        #     **{
        #         "views / %d" % i : item[2]
        #         for i, item in enumerate(gbuf_view_items)
        #     },
        # })

        return gbuf_view_items
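Aside (not part of the commit): slicing a padded buffer into data_parallel_world_size equal views, as the loop above does, versus what happens without padding (the last view comes up short, which reduce_scatter rejects). Toy sizes only.

import torch

data_parallel_world_size = 4

padded = torch.zeros(12)          # padded buffer: 12 % 4 == 0
shard_size = padded.numel() // data_parallel_world_size
views = [padded[r * shard_size:(r + 1) * shard_size]
         for r in range(data_parallel_world_size)]
print([v.numel() for v in views])   # [3, 3, 3, 3] -> equal-size reduce_scatter inputs

unpadded = torch.zeros(10)        # 10 % 4 != 0
views = [unpadded[r * 3:(r + 1) * 3] for r in range(data_parallel_world_size)]
print([v.numel() for v in views])   # [3, 3, 3, 1] -> unequal sizes; reduce_scatter would fail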
...
...
@@ -1361,11 +1359,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            gbuf /= data_parallel_world_size

            # if 1:
            # try:
            # pax(0, {"gbuf_views": gbuf_views})
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )
            # except:
            #     pax(0, {
            #         "data_parallel_rank" : data_parallel_rank,
            #         "gbuf_views" : gbuf_views,
            #     })
            # else:
            #     torch.distributed.all_reduce(
            #         gbuf,
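Aside (not part of the commit): a single-process simulation of what the reduce_scatter call above computes. Each data-parallel rank contributes its full (already averaged) grad buffer, and rank r receives the element-wise sum of every rank's r-th view. Real usage needs an initialized NCCL process group; this only mimics the arithmetic with toy values.

import torch

data_parallel_world_size = 4
shard_size = 3

# One padded, already-averaged grad buffer per rank (toy values).
gbufs = [torch.full((data_parallel_world_size * shard_size,), float(rank + 1))
         for rank in range(data_parallel_world_size)]

# reduce_scatter semantics: rank r ends up with the sum over ranks of view r.
for rank in range(data_parallel_world_size):
    views = [g[rank * shard_size:(rank + 1) * shard_size] for g in gbufs]
    result = torch.stack(views).sum(dim=0)
    print(rank, result.tolist())   # every rank gets [10.0, 10.0, 10.0] for these inputs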
...
...