OpenDAS / Megatron-LM · Commits

Commit f498a6fe
Authored Feb 15, 2022 by Lawrence McAfee

    modularized shard indexing

Parent: cb6f96b6

Showing 1 changed file with 90 additions and 76 deletions (+90 -76):

    megatron/optimizer/optimizer.py (+90 -76)
@@ -661,6 +661,17 @@ from megatron import get_args
 from lutil import pax, tp
 # <<<
 
+# class ShardIndex:
+class Shard:
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+        self.size = end - start
+    def normalize(self, start = 0):
+        return Shard(start, start + self.size)
+    def __str__(self):
+        return "%d,%d [%d]" % (self.start, self.end, self.size)
+
 # class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
 # class Float16DistributedOptimizer(MegatronOptimizer):
 class Float16DistributedOptimizer(BaseFloat16Optimizer):
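For orientation: the new Shard helper represents a half-open [start, end) index range, and normalize() rebases a same-size range onto a new start. A minimal sketch of its behaviour, assuming only the class definition added above (the values are illustrative and not part of the commit):

    # Hypothetical usage of the Shard class added above.
    shard = Shard(100, 164)        # range of 64 elements in buffer coordinates
    print(shard)                   # -> 100,164 [64]
    local = shard.normalize()      # same size, rebased to start at 0
    print(local)                   # -> 0,64 [64]
    print(local.normalize(100))    # -> 100,164 [64]  (rebased back)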
@@ -921,83 +932,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # # <<<
 
     @classmethod
     # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
-    def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+    # def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):
+    # def get_model_gbuf_param_shard_index_map(cls,model,dtype,gbuf_world_index):
+    def get_model_gbuf_param_shard_map(cls, model, dtype, gbuf_world_shard):
 
+        # Param shard map.
+        param_world_index_map = model._grad_buffer_param_index_map[dtype]
         param_shard_map = {}
-        for param, indexes in \
-            model._grad_buffer_param_index_map[dtype].items():
-
-            param_gbuf_start, param_gbuf_end = indexes
-
-            # Shard range.
-            param_shard_start = max(
-                0,
-                param_gbuf_start - shard_start)
-            param_shard_end = min(
-                shard_end,
-                param_gbuf_end - shard_start)
-
-            if param_shard_end > param_shard_start:
-                dtype_info["grad_buffer_param_shards"][param] = {
-                    "gbuf_start" : param_gbuf_start,
-                    "shard_start" : param_shard_start,
-                    "size" : param_shard_end - param_shard_start,
-                }
-
-            # pax(0, {
-            #     "param" : param,
-            #     "indexes" : indexes,
-            #     "param_gbuf_start" : param_gbuf_start,
-            #     "param_gbuf_end" : param_gbuf_end,
-            #     "param_shard_start" : param_shard_start,
-            #     "param_shard_end" : param_shard_end,
-            # })
-
-        pax(0, {"param_shard_map": param_shard_map})
+        for param, param_world_indexes in param_world_index_map.items():
+
+            param_world_start, param_world_end = param_world_indexes
+            param_local_start = max(
+                0,
+                param_world_start - gbuf_world_shard.start)
+            param_local_end = min(
+                gbuf_world_shard.size,
+                param_world_end - gbuf_world_shard.start)
+
+            # Add shard, if within range.
+            if param_local_end > param_local_start:
+                param_local_shard = Shard(param_local_start, param_local_end)
+                param_world_shard = param_local_shard.normalize(param_world_start)
+                param_shard_map[param] = {
+                    "local" : param_local_shard,
+                    "world" : param_world_shard,
+                }
+
+        # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
 
         return param_shard_map
 
     @classmethod
-    def get_ddp_gbuf_shard(cls, model, dtype):
-
-        # Per-dtype info.
-        dtype_info = {}
-        model_info[dtype] = dtype_info
+    # def get_ddp_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard(cls, model, dtype):
+    # def get_model_gbuf_shard_index(cls, model, dtype):
+    def get_model_gbuf_shard(cls, model, dtype):
+
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
 
         # Grad buffer shard.
-        model_param_size = grad_buffer.numel
-        max_world_shard_size = int(math.ceil(
-            model_param_size / self.data_parallel_world_size))
-        shard_start = rank * max_world_shard_size
-        shard_end = min(model_param_size,
-                        shard_start + max_world_shard_size)
-
-        dtype_info["grad_buffer_shard"] = {
-            "start" : shard_start,
-            "end" : shard_end,
-            "size" : shard_end - shard_start,
-        }
-
-        # Grad buffer param shards.
-        dtype_info["grad_buffer_param_shards"] = self.get_ddp_gbuf_param_shards()
-        pax(0, {"grad_buffer_param_shards": [
-            str((str(tuple(p.shape)), i))
-            for p, i in dtype_info["grad_buffer_param_shards"].items()
-        ]})
-
-        return ddp_gbuf_shard
+        grad_buffer = model._grad_buffers[dtype]
+        gbuf_size = grad_buffer.numel
+        max_gbuf_shard_size = int(math.ceil(
+            gbuf_size / data_parallel_world_size))
+        gbuf_world_start = data_parallel_rank * max_gbuf_shard_size
+        gbuf_world_end = min(gbuf_size,
+                             gbuf_world_start + max_gbuf_shard_size)
+        gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)
+        gbuf_local_shard = gbuf_world_shard.normalize()
+        # gbuf_local_shard = Shard(0, gbuf_world_index.size)
+
+        # Param shards.
+        param_shard_map = cls.get_model_gbuf_param_shard_map(model, dtype, gbuf_world_shard)
+
+        # Altogether.
+        data = {
+            "local" : gbuf_local_shard,
+            "world" : gbuf_world_shard,
+            "param_map" : param_shard_map,
+        }
+
+        # pax(0, {"data": data})
+
+        return data
 
     @classmethod
     # def get_ddp_gbuf_shards(cls, model):
-    def get_ddp_gbuf_shard_map(cls, model):
+    # def get_ddp_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_map(cls, model):
+    # def get_model_gbuf_shard_index_map(cls, model):
+    def get_model_gbuf_shard_map(cls, model):
+        # shard_index_map = {
         shard_map = {
-            dtype : cls.get_ddp_gbuf_shard(model, dtype)
+            dtype : cls.get_model_gbuf_shard(model, dtype)
             for dtype in model._grad_buffers
         }
-        pax(0, {"shard_map": shard_map})
+        # pax(0, {"shard_map": shard_map})
         return shard_map
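To make the arithmetic in get_model_gbuf_shard and get_model_gbuf_param_shard_map concrete, here is a self-contained sketch of the same computation with the Megatron objects (mpu, model._grad_buffers) replaced by plain numbers. The buffer size, rank, and parameter range below are made-up values for illustration only; only the Shard class from the hunk above is assumed.

    import math

    # Hypothetical numbers, chosen only to illustrate the arithmetic above.
    gbuf_size = 1000                  # grad_buffer.numel
    data_parallel_world_size = 4
    data_parallel_rank = 1

    # Each data-parallel rank owns one contiguous, equal-sized slice of the
    # grad buffer (the last rank's slice may be shorter).
    max_gbuf_shard_size = int(math.ceil(gbuf_size / data_parallel_world_size))  # 250
    gbuf_world_start = data_parallel_rank * max_gbuf_shard_size                 # 250
    gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_shard_size)     # 500
    gbuf_world_shard = Shard(gbuf_world_start, gbuf_world_end)                  # 250,500 [250]
    gbuf_local_shard = gbuf_world_shard.normalize()                             # 0,250 [250]

    # A parameter occupying grad-buffer indexes [300, 360) falls inside this
    # rank's slice; its "local" shard is relative to the rank's slice, and for
    # this fully contained parameter the "world" shard coincides with its
    # range in grad-buffer coordinates.
    param_world_start, param_world_end = 300, 360
    param_local_start = max(0, param_world_start - gbuf_world_shard.start)      # 50
    param_local_end = min(gbuf_world_shard.size,
                          param_world_end - gbuf_world_shard.start)             # 110
    if param_local_end > param_local_start:
        param_local_shard = Shard(param_local_start, param_local_end)           # 50,110 [60]
        param_world_shard = param_local_shard.normalize(param_world_start)      # 300,360 [60]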
@@ -1017,10 +1032,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # pax(0, {"models": models})
 
-        # Data parallel info.
-        self.data_parallel_group = mpu.get_data_parallel_group()
-        self.data_parallel_rank = mpu.get_data_parallel_rank()
-        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
+        # # Data parallel info.
+        # self.data_parallel_group = mpu.get_data_parallel_group()
+        # self.data_parallel_rank = mpu.get_data_parallel_rank()
+        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()
 
         # Param group map.
         self.param_group_map = {}
@@ -1037,25 +1052,24 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Shard allocator.
         # ** torch.nn.Parameter ??
         # ** MemoryBuffer ??
-        allocate_shard = lambda shard_size, dtype : torch.empty(
-            (shard_size,),
-            dtype = dtype,
-            device = torch.cuda.current_device(),
-            requires_grad = True)
-
-        # World shard infos.
-        self.world_shard_infos = []
-        for rank in range(self.data_parallel_world_size):
-
-            # Per-rank info.
-            rank_info = []
-            self.world_shard_infos.append(rank_info)
-
-            for model_index, model in enumerate(self.models):
-
-                # Per-virtual-model info.
-                # model_info = {}
-                # rank_info.append(model_info)
-
-                ddp_gbuf_shards = self.get_ddp_gbuf_shards(model)
+        # allocate_shard = lambda shard_size, dtype : torch.empty(
+        #     (shard_size,),
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device(),
+        #     requires_grad = True)
+
+        # Model grad buffer shards.
+        self.model_gbuf_shards = []
+        for model_index, model in enumerate(self.models):
+            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
+
+        # Allocate main param/grad shard.
+        param_shard_map = self.get_param_shard_map(self.model_gbuf_shards)
+
+        pax(0, {
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "param_shard_map" : param_shard_map,
+        })
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
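For reference, a hypothetical inspection snippet (not part of the commit) that could be dropped into the constructor after the loop above to walk the structure it builds: self.model_gbuf_shards holds one map per virtual-pipeline model chunk, keyed by grad-buffer dtype, with each entry carrying the rank's "local"/"world" Shard of that buffer plus the per-parameter shard map.

    # Hypothetical inspection code, assuming the structure returned by
    # get_model_gbuf_shard_map() above.
    for model_index, gbuf_shard_map in enumerate(self.model_gbuf_shards):
        for dtype, gbuf_shard in gbuf_shard_map.items():
            print(model_index, dtype,
                  "local", gbuf_shard["local"],   # rank's slice, rebased to 0
                  "world", gbuf_shard["world"])   # same slice in grad-buffer coords
            for param, param_shard in gbuf_shard["param_map"].items():
                print("   ", tuple(param.shape),
                      "local", param_shard["local"],
                      "world", param_shard["world"])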