OpenDAS / Megatron-LM
Commit c5f93269, authored Feb 14, 2022 by Lawrence McAfee

    map param to originating virtual model; eventually move this to constructor

Parent: 3ded2425
Changes: 2 changed files with 63 additions and 9 deletions (+63, -9)

    megatron/model/distributed.py    +17  -4
    megatron/optimizer/optimizer.py  +46  -5
megatron/model/distributed.py
@@ -123,11 +123,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
         self._grad_buffers = None
         # >>>
         from collections import defaultdict
-        self._grad_buffer_param_offsets = None
+        # self._grad_buffer_param_offsets = None
+        self._grad_buffer_param_index_map = None
         # <<<
         if self.use_contiguous_buffers:
             self._grad_buffers = {}
-            self._grad_buffer_param_offsets = defaultdict(dict)
+            # >>>
+            # self._grad_buffer_param_offsets = defaultdict(dict)
+            # self._grad_buffer_param_index_map = defaultdict(dict)
+            self._grad_buffer_param_index_map = {}
+            # <<<

             # Simple function to define buffer type.
             def _get_buffer_type(param):
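Note: the new `_grad_buffer_param_index_map` replaces the old per-param offset dict with explicit per-param index bookkeeping, nested first by gradient dtype and then by parameter. Below is a minimal, self-contained sketch of that layout; the tensors, the forward packing order, and the `torch.float` key are illustrative assumptions (the hunk that follows packs parameters from the end of the buffer by decrementing `type_num_elements` before recording the start).

    import torch

    # Illustration only: the nested {dtype: {param: {"start", "end"}}} layout
    # that _grad_buffer_param_index_map is meant to hold.
    params = [torch.nn.Parameter(torch.zeros(4, 3)),
              torch.nn.Parameter(torch.zeros(7))]

    index_map = {}
    offset = 0
    for p in params:
        index_map.setdefault(torch.float, {})[p] = {
            "start": offset,
            "end": offset + p.nelement(),
        }
        offset += p.nelement()

    print({tuple(p.shape): v for p, v in index_map[torch.float].items()})
    # {(4, 3): {'start': 0, 'end': 12}, (7,): {'start': 12, 'end': 19}}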
@@ -154,8 +159,16 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     type_num_elements[dtype] -= param.data.nelement()
                     param.main_grad = self._grad_buffers[dtype].get(
                         param.data.shape, type_num_elements[dtype])
-                    self._grad_buffer_param_offsets[dtype][param] = \
-                        type_num_elements[dtype]
+                    # >>>
+                    # self._grad_buffer_param_offsets[dtype][param] = \
+                    #     type_num_elements[dtype]
+                    if dtype not in self._grad_buffer_param_index_map:
+                        self._grad_buffer_param_index_map[dtype] = {}
+                    self._grad_buffer_param_index_map[dtype][param] = {
+                        "start" : type_num_elements[dtype],
+                        "end" : param.data.nelement(),
+                    }
+                    # <<<

             # Backward hook.
             # Accumalation function for the gradients. We need
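Note: each `self._grad_buffers[dtype]` is a flat, contiguous tensor, and `param.main_grad` is the view into it returned by the buffer's `get(shape, start_index)` helper, so the recorded indices describe where each parameter's gradient lives in that flat storage. A rough, self-contained sketch of that slicing, with plain tensors standing in for Megatron's `MemoryBuffer`:

    import torch

    # Stand-in for a contiguous grad buffer holding two params back to back.
    flat_buffer = torch.zeros(19)
    shapes = [(4, 3), (7,)]

    views, start = [], 0
    for shape in shapes:
        n = torch.Size(shape).numel()
        # Equivalent of MemoryBuffer.get(shape, start): a view, not a copy.
        views.append(flat_buffer[start:start + n].view(shape))
        start += n

    views[0].fill_(1.0)            # writing through the view ...
    print(flat_buffer[:12].sum())  # ... is visible in the flat buffer: tensor(12.)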
megatron/optimizer/optimizer.py
@@ -775,7 +775,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {
         #     "model_param_group" : model_param_group,
-        #     # "offset_map" : {str(p.shape):o for p, o in model_param_group["offset_map"].items()},
         #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
         # })

@@ -843,10 +842,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(2, {
         #     "data_parallel_rank" : self.data_parallel_rank,
         #     "local_shard_info" : local_shard_info,
-        #     "param_index_map " : {
-        #         str(p.shape) : i
+        #     "param_index_map " : [
+        #         (str(p.shape), i)
         #         for p, i in local_shard_info["param_index_map"].items()
-        #     },
+        #     ],
         # })

         # Allocate shards.

@@ -904,15 +903,57 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # timers = get_timers()
         # <<<

-        # >>> [ already checked in arguments.py ]
+        # >>> [ temporary requirement ... and already checked in arguments.py ]
         assert args.use_contiguous_buffers_in_local_ddp
         # <<<

+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Map param to virtual model.
+        # ** ideally, this should happen once, during construction.
+        param_model_map = {}
+        for vmodel in model:
+            for dtype, param_index_map in \
+                vmodel._grad_buffer_param_index_map.items():
+                for param in param_index_map:
+                    param_model_map[param] = {
+                        "dtype" : dtype,
+                        "model" : vmodel,
+                    }
+
+        # pax(0, {
+        #     "param_model_map" : [
+        #         (str(tuple(p.shape)), m)
+        #         for p, m in param_model_map.items()
+        #     ],
+        # })
+
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Copy model grads to main shard.
         local_shard_info_groups = [ g[self.data_parallel_rank]
                                     for g in self.world_shard_info_groups ]
+        for group_index, local_shard_info in enumerate(local_shard_info_groups):
+
+            # model_param_index_map =
+            shard_param_index_map = local_shard_info["param_index_map"]
+            for param, shard_indexes in shard_param_index_map.items():
+
+                dtype_model_dict = param_model_map[param]
+                dtype = dtype_model_dict["dtype"]
+                vmodel = dtype_model_dict["model"]
+                grad_buffer_indexes = \
+                    vmodel._grad_buffer_param_index_map[dtype][param]
+
+                pax(0, {"dtype": dtype})
+                pax(0, {
+                    "group_index" : group_index,
+                    "local_shard_info" : local_shard_info,
+                    "shard_param_index_map" : shard_param_index_map,
+                    "param" : tp(param),
+                    "shard_indexes" : shard_indexes,
+                    "grad_buffer_indexes" : grad_buffer_indexes,
+                })
+
         pax(0, {
             # "world_shard_info_groups" : self.world_shard_info_groups,
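Note: the new loop stops at `pax(...)` dumps (`pax` and `tp` appear to be local debugging helpers rather than mainline Megatron-LM utilities), so the actual gradient copy is not implemented in this commit. The sketch below is only a guess at the intended next step, with made-up names and start/end index conventions: fill the optimizer's fp32 shard slice from the matching slice of the owning chunk's contiguous grad buffer.

    import torch

    # Hypothetical illustration of the eventual "copy model grads to main shard"
    # step; the function name and index conventions are assumptions, not this
    # commit's API.
    def copy_grads_to_main_shard(main_grad_shard, grad_buffer,
                                 shard_indexes, grad_buffer_indexes):
        """Copy one param's grad slice from a chunk's flat grad buffer into the
        optimizer's fp32 shard, converting dtype in the process."""
        dst = main_grad_shard[shard_indexes["start"]:shard_indexes["end"]]
        src = grad_buffer[grad_buffer_indexes["start"]:grad_buffer_indexes["end"]]
        dst.copy_(src)  # copy_ handles the fp16 -> fp32 cast

    # Toy usage: a 19-element fp16 grad buffer; this shard owns elements 12..19.
    grad_buffer = torch.arange(19, dtype=torch.float16)
    main_grad_shard = torch.zeros(7, dtype=torch.float32)
    copy_grads_to_main_shard(main_grad_shard, grad_buffer,
                             {"start": 0, "end": 7},
                             {"start": 12, "end": 19})
    print(main_grad_shard)  # tensor([12., 13., 14., 15., 16., 17., 18.])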