OpenDAS / Megatron-LM

Commit cb6f96b6, authored Feb 15, 2022 by Lawrence McAfee
wip; switching to grad-buffer-centric design
Parent: a3f3c3ad

Showing 4 changed files with 351 additions and 193 deletions.
megatron/arguments.py            +5   -3
megatron/optimizer/__init__.py   +7   -6
megatron/optimizer/optimizer.py  +334 -183
megatron/training.py             +5   -1
megatron/arguments.py
...
...
@@ -130,9 +130,11 @@ def parse_args(extra_args_provider=None, defaults={},
              args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        # >>> [ temporarily turning off ]
        # assert args.pipeline_model_parallel_size > 2, \
        #     'pipeline-model-parallel size should be greater than 2 with ' \
        #     'interleaved schedule'
        # <<<
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
...
...
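For reference, the assertion being relaxed here guards Megatron's interleaved pipeline schedule, in which each pipeline rank hosts several small model chunks rather than one contiguous block of layers. The sketch below illustrates the bookkeeping that the surviving divisibility assert protects; the helper name and example numbers are illustrative, not part of the commit.

    # Illustrative only: mirrors the divisibility check in parse_args(),
    # not Megatron's internal argument wiring.
    def chunks_per_pipeline_rank(num_layers, pipeline_model_parallel_size,
                                 num_layers_per_virtual_pipeline_stage):
        assert num_layers % num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        layers_per_rank = num_layers // pipeline_model_parallel_size
        return layers_per_rank // num_layers_per_virtual_pipeline_stage

    # e.g. 24 layers on 4 pipeline ranks with 3 layers per virtual stage:
    # 6 layers per rank, interleaved as 2 model chunks per rank.
    print(chunks_per_pipeline_rank(24, 4, 3))  # 2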
megatron/optimizer/__init__.py
...
...
@@ -97,11 +97,11 @@ def get_megatron_optimizer(model,
    # from lutil import pax
    # pax(0, {
    #     "model" : model,
    #     "param_groups" : param_groups,
    #     "param_groups / 0" : param_groups[0],
    #     "param_groups / 0 / params" : param_groups[0]["params"],
    #     "param_groups / 1" : param_groups[1],
    #     "param_groups / 1 / params" : param_groups[1]["params"],
    # })
    # <<<
...
...
@@ -164,7 +164,8 @@ def get_megatron_optimizer(model,
            params_have_main_grad,
            args.use_contiguous_buffers_in_local_ddp,
            args.bf16,
            grad_scaler)
            grad_scaler,
            model)
        # <<<

    # FP32.
...
...
megatron/optimizer/optimizer.py
...
...
@@ -184,12 +184,16 @@ class BaseFloat16Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler):
                 bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        # >>>
        self.models = models
        # <<<

        self.bf16 = bf16
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
...
...
@@ -697,65 +701,338 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# def __init__(self, *_args):
# super().__init__(*_args)
# def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler):
# super().__init__(
# optimizer, clip_grad, log_num_zeros_in_grad,
# params_have_main_grad, use_contiguous_buffers_in_local_ddp,
# bf16, grad_scaler)
# # >>>
# # self.test_reduce_scatter()
# # <<<
# # >>>
# args = get_args()
# # <<<
# # Data parallel info.
# self.data_parallel_group = mpu.get_data_parallel_group()
# self.data_parallel_rank = mpu.get_data_parallel_rank()
# self.data_parallel_world_size = mpu.get_data_parallel_world_size()
# # Total trainable param count.
# # self.total_param_size = sum(
# # p.numel()
# # for g in self.param_groups
# # for p in g["params"]
# # # if p .requires_grad ???
# # )
# # Model params: group sizes, group offset maps.
# # self.model_params = []
# # self.model_param_group_sizes = []
# # self.model_param_group_offset_maps = []
# self.model_param_groups = []
# for param_group in self.optimizer.param_groups:
# param_group_offset = 0
# param_group_offset_map = {}
# for param in param_group['params']:
# if not param.requires_grad:
# continue
# # self.model_params.append(param)
# param_group_offset_map[param] = {
# "start" : param_group_offset,
# "end" : param_group_offset + param.numel(),
# }
# param_group_offset += param.numel()
# # self.model_param_group_sizes.append(param_group_offset)
# # self.model_param_group_offset_maps.append(param_group_offset_map)
# self.model_param_groups.append({
# "size" : param_group_offset,
# "offset_map" : param_group_offset_map,
# })
# # pax(0, {
# # "model_params" : model_params,
# # "model_param_group_sizes" : model_param_group_sizes,
# # "model_param_group_offset_maps" : model_param_group_offset_maps,
# # })
# # Shard allocator.
# # ** torch.nn.Parameter ??
# # ** MemoryBuffer ??
# allocate_shard = lambda shard_size, dtype : torch.empty(
# (shard_size,),
# dtype = dtype,
# device = torch.cuda.current_device(),
# requires_grad = True)
# # Allocate shards.
# # (Also, collect world DP shard info.)
# # model_main_dtypes = set([ args.params_dtype, torch.float ])
# model_main_dtypes = set([ torch.float ]) # fp32 only, for now
# # self.world_shard_info_groups = [] # world_group_shard_infos ?
# # self.main_param_shard_groups = []
# self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
# for group_index, model_param_group in enumerate(self.model_param_groups):
# # Max world shard size.
# model_param_size = model_param_group["size"]
# max_world_shard_size = int(math.ceil(model_param_size /
# self.data_parallel_world_size))
# # DP world shard infos.
# # world_shard_infos = []
# for r in range(self.data_parallel_world_size):
# shard_start_index = r * max_world_shard_size
# shard_end_index = min(model_param_size,
# shard_start_index + max_world_shard_size)
# # world_shard_infos.append({
# self.world_shard_infos[r]["groups"].append({
# "start" : shard_start_index,
# "end" : shard_end_index,
# "size" : shard_end_index - shard_start_index,
# })
# # self.world_shard_info_groups.append(world_shard_infos)
# # self.world_shard_infos[group_index].append(world_shard_infos)
# # DP local rank's shard info.
# # local_shard_info = world_shard_infos[self.data_parallel_rank]
# local_shard_info = \
# self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
# local_shard_start_index = local_shard_info["start"]
# local_shard_end_index = local_shard_info["end"]
# local_shard_size = local_shard_info["size"]
# # Local shard's param 'slice' index map.
# local_shard_info["param_slice_index_map"] = {}
# for param, offset_dict in model_param_group["offset_map"].items():
# # param_start_index = offset_dict["start"]
# # param_end_index = offset_dict["end"]
# # param_shard_start_index = max(local_shard_start_index,
# # param_start_index)
# # param_shard_end_index = min(local_shard_end_index,
# # param_end_index)
# orig_start_index = offset_dict["start"]
# orig_end_index = offset_dict["end"]
# shard_start_index = max(
# 0,
# orig_start_index - local_shard_start_index)
# shard_end_index = min(
# local_shard_end_index,
# orig_end_index - local_shard_start_index)
# # if param_shard_end_index > param_shard_start_index:
# # # Indexes are relative to local shard start index.
# # # local_shard_info["param_index_map"][param] = {
# # # "param" : (
# # # param_shard_start_index,
# # # param_shard_end_index,
# # # ),
# # # "shard" : (
# # # param_shard_start_index - local_shard_start_index,
# # # param_shard_end_index - local_shard_start_index,
# # # ),
# # # }
# # local_shard_info["param_slice_index_map"][param] = {
# # "param_start" :
# # param_shard_start_index,
# # "shard_start" :
# # param_shard_start_index - local_shard_start_index,
# # "size":
# # param_shard_end_index - param_shard_start_index,
# # }
# if shard_end_index > shard_start_index:
# local_shard_info["param_slice_index_map"][param] = {
# "orig_start" : orig_start_index,
# "shard_start" : shard_start_index,
# "size" : shard_end_index - shard_start_index,
# }
# # pax(0, {
# # "local index" : "%d, %d" % (
# # local_shard_start_index,
# # local_shard_end_index,
# # ),
# # "param index" : "%s, %d" % (
# # param_start_index,
# # param_end_index,
# # ),
# # "param" : tp(param),
# # "shard_param_index_map" : shard_param_index_map,
# # "local_shard_info" : local_shard_info,
# # })
# # pax(2, {
# # "data_parallel_rank" : self.data_parallel_rank,
# # "local_shard_info" : local_shard_info,
# # "param_index_map " : [
# # (str(p.shape), i)
# # for p, i in local_shard_info["param_index_map"].items()
# # ],
# # })
# # Allocate shards.
# # (Non-fp32 shards are for convenience; e.g., intermediaries
# # between model params and main fp32 shard. Necessary???)
# # main_param_shards = {
# # ty : allocate_shard(local_shard_size, ty)
# # for ty in model_main_dtypes}
# main_param_shards = {}
# for dtype in model_main_dtypes:
# main_param = allocate_shard(local_shard_size, dtype)
# main_param.grad = allocate_shard(local_shard_size, dtype)
# # pax(0, {"main_param": main_param})
# main_param_shards[dtype] = main_param
# # self.main_param_shard_groups.append(main_param_shards)
# local_shard_info["data"] = main_param_shards
# # Update optimizer group.
# self.optimizer.param_groups[group_index]["params"] = \
# [ main_param_shards[torch.float] ]
# # pax(0, {
# # "param_groups" : self.optimizer.param_groups,
# # "params" : self.optimizer.param_groups[group_index]["params"],
# # })
# # Add world start/end indexes, for reduce/gather steps.
# offset = 0
# for r in self.world_shard_infos:
# r["start_index"] = offset
# offset += sum(g["size"] for g in r["groups"])
# r["end_index"] = offset
# # Leverage state_dict() and load_state_dict() to
# # recast preexisting per-param state tensors
# self.optimizer.load_state_dict(self.optimizer.state_dict())
# # >>>
# # pax(0, {
# # "world_shard_infos" : self.world_shard_infos,
# # **{
# # "world_shard_infos / %d" % i : r
# # for i, r in enumerate(self.world_shard_infos)
# # },
# # })
# # <<<
    @classmethod
    # def get_ddp_gbuf_param_shards(cls, model, dtype, gbuf_start):
    def get_ddp_gbuf_param_shard_map(cls, model, dtype, gbuf_start):

        param_shard_map = {}
        for param, indexes in \
            model._grad_buffer_param_index_map[dtype].items():

            param_gbuf_start, param_gbuf_end = indexes
            param_shard_start = max(0, param_gbuf_start - shard_start)
            param_shard_end = min(shard_end, param_gbuf_end - shard_start)

            if param_shard_end > param_shard_start:
                dtype_info["grad_buffer_param_shards"][param] = {
                    "gbuf_start" : param_gbuf_start,
                    "shard_start" : param_shard_start,
                    "size" : param_shard_end - param_shard_start,
                }

            # pax(0, {
            #     "param" : param,
            #     "indexes" : indexes,
            #     "param_gbuf_start" : param_gbuf_start,
            #     "param_gbuf_end" : param_gbuf_end,
            #     "param_shard_start" : param_shard_start,
            #     "param_shard_end" : param_shard_end,
            # })

        pax(0, {"param_shard_map": param_shard_map})

        return param_shard_map
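As written, the method above is mid-refactor: it reads shard_start, shard_end, and dtype_info, which belong to get_ddp_gbuf_shard below rather than to its own scope, and it fills dtype_info instead of the param_shard_map it returns. The following is a self-contained sketch of the overlap arithmetic it appears to be converging on; the helper name is hypothetical and not part of the commit.

    # Hypothetical helper (not in the commit): clip a parameter's
    # [param_gbuf_start, param_gbuf_end) range in the grad buffer against this
    # rank's [shard_start, shard_end) slice, returning shard-relative indexes.
    def param_shard_overlap(param_gbuf_start, param_gbuf_end, shard_start, shard_end):
        start = max(0, param_gbuf_start - shard_start)
        end = min(shard_end - shard_start, param_gbuf_end - shard_start)
        if end <= start:
            return None  # param does not overlap this rank's shard
        return {"gbuf_start": param_gbuf_start, "shard_start": start, "size": end - start}

    # A param occupying grad-buffer elements [90, 130) against a shard covering
    # [100, 200) contributes elements [0, 30) of the local shard.
    print(param_shard_overlap(90, 130, 100, 200))
    # {'gbuf_start': 90, 'shard_start': 0, 'size': 30}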
    @classmethod
    def get_ddp_gbuf_shard(cls, model, dtype):

        # Per-dtype info.
        dtype_info = {}
        model_info[dtype] = dtype_info

        # Grad buffer shard.
        model_param_size = grad_buffer.numel
        max_world_shard_size = int(math.ceil(model_param_size /
                                             self.data_parallel_world_size))
        shard_start = rank * max_world_shard_size
        shard_end = min(model_param_size, shard_start + max_world_shard_size)
        dtype_info["grad_buffer_shard"] = {
            "start" : shard_start,
            "end" : shard_end,
            "size" : shard_end - shard_start,
        }

        # Grad buffer param shards.
        dtype_info["grad_buffer_param_shards"] = self.get_ddp_gbuf_param_shards()

        pax(0, {"grad_buffer_param_shards": [
            str((str(tuple(p.shape)), i))
            for p, i in dtype_info["grad_buffer_param_shards"].items()
        ]})

        return ddp_gbuf_shard
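The partitioning above is plain ceiling division of the flat grad buffer across the data-parallel group, so every rank gets the same maximum shard size and the last rank may get a short one. A standalone sketch of that arithmetic (names are illustrative):

    import math

    # Illustrative only: split a flat buffer of `numel` elements across
    # `world_size` data-parallel ranks; the last shard may be smaller.
    def even_shards(numel, world_size):
        max_shard = int(math.ceil(numel / world_size))
        shards = []
        for rank in range(world_size):
            start = rank * max_shard
            end = min(numel, start + max_shard)
            shards.append({"start": start, "end": end, "size": end - start})
        return shards

    # e.g. a 10-element grad buffer over 4 ranks -> shard sizes 3, 3, 3, 1.
    print([s["size"] for s in even_shards(10, 4)])  # [3, 3, 3, 1]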
    @classmethod
    # def get_ddp_gbuf_shards(cls, model):
    def get_ddp_gbuf_shard_map(cls, model):

        shard_map = {
            dtype : cls.get_ddp_gbuf_shard(model, dtype)
            for dtype in model._grad_buffers
        }

        pax(0, {"shard_map": shard_map})

        return shard_map
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler):
                 bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            bf16, grad_scaler)
        # >>>
        # self.test_reduce_scatter()
        # <<<
            bf16, grad_scaler, models)

        # >>>
        args = get_args()
        assert args.use_contiguous_buffers_in_local_ddp # already checked in args
        # <<<

        # pax(0, {"models": models})
        # Data parallel info.
        self.data_parallel_group = mpu.get_data_parallel_group()
        self.data_parallel_rank = mpu.get_data_parallel_rank()
        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
# Total trainable param count.
# self.total_param_size = sum(
# p.numel()
# for g in self.param_groups
# for p in g["params"]
# # if p .requires_grad ???
# )
# Model params: group sizes, group offset maps.
# self.model_params = []
# self.model_param_group_sizes = []
# self.model_param_group_offset_maps = []
        self.model_param_groups = []
        for param_group in self.optimizer.param_groups:
            param_group_offset = 0
            param_group_offset_map = {}
            for param in param_group['params']:
                if not param.requires_grad:
                    continue
                # self.model_params.append(param)
                param_group_offset_map[param] = {
                    "start" : param_group_offset,
                    "end" : param_group_offset + param.numel(),
                }
                param_group_offset += param.numel()
            # self.model_param_group_sizes.append(param_group_offset)
            # self.model_param_group_offset_maps.append(param_group_offset_map)
            self.model_param_groups.append({
                "size" : param_group_offset,
                "offset_map" : param_group_offset_map,
            })

        # Param group map.
        self.param_group_map = {}
        for group_index, group in enumerate(self.optimizer.param_groups):
            for param in group["params"]:
                assert param.requires_grad
                self.param_group_map[param] = group_index
# pax(0, {
# "model_params" : model_params,
# "model_param_group_sizes" : model_param_group_sizes,
# "model_param_group_offset_maps" : model_param_group_offset_maps,
# })
# pax(0, {"param_group_map": [
# (g, str(p.shape))
# for p, g in self.param_group_map.items()
# ]})
# Shard allocator.
# ** torch.nn.Parameter ??
...
...
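The loops shown above flatten each optimizer param group into one contiguous index space, recording a per-parameter [start, end) range plus a param-to-group map; those ranges are what later get intersected with a rank's data-parallel shard. A standalone sketch of that flattening, with illustrative tensors and function name:

    import torch

    # Illustrative only: map each parameter to its [start, end) offsets in a
    # flat, group-wide index space, ordered by iteration order.
    def build_offset_map(params):
        offset_map, offset = {}, 0
        for p in params:
            offset_map[p] = {"start": offset, "end": offset + p.numel()}
            offset += p.numel()
        return offset_map, offset  # offset == total group size

    w = torch.empty(4, 8)   # 32 elements -> [0, 32)
    b = torch.empty(8)      # 8 elements  -> [32, 40)
    offset_map, size = build_offset_map([w, b])
    print(size)  # 40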
@@ -766,154 +1043,28 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            device = torch.cuda.current_device(),
            requires_grad = True)

        # Allocate shards.
        # (Also, collect world DP shard info.)
        # model_main_dtypes = set([ args.params_dtype, torch.float ])
        model_main_dtypes = set([ torch.float ]) # fp32 only, for now

        # self.world_shard_info_groups = [] # world_group_shard_infos ?
        # self.main_param_shard_groups = []
        self.world_shard_infos = [{"groups": []} for _ in self.model_param_groups]
        for group_index, model_param_group in enumerate(self.model_param_groups):

            # Max world shard size.
            model_param_size = model_param_group["size"]
            max_world_shard_size = int(math.ceil(model_param_size /
                                                 self.data_parallel_world_size))

            # DP world shard infos.
            # world_shard_infos = []
            for r in range(self.data_parallel_world_size):
                shard_start_index = r * max_world_shard_size
                shard_end_index = min(model_param_size,
                                      shard_start_index + max_world_shard_size)
                # world_shard_infos.append({
                self.world_shard_infos[r]["groups"].append({
                    "start" : shard_start_index,
                    "end" : shard_end_index,
                    "size" : shard_end_index - shard_start_index,
                })
            # self.world_shard_info_groups.append(world_shard_infos)
            # self.world_shard_infos[group_index].append(world_shard_infos)

            # DP local rank's shard info.
            # local_shard_info = world_shard_infos[self.data_parallel_rank]
            local_shard_info = \
                self.world_shard_infos[self.data_parallel_rank]["groups"][-1]
            local_shard_start_index = local_shard_info["start"]
            local_shard_end_index = local_shard_info["end"]
            local_shard_size = local_shard_info["size"]

            # Local shard's param 'slice' index map.
            local_shard_info["param_slice_index_map"] = {}
            for param, offset_dict in model_param_group["offset_map"].items():
# param_start_index = offset_dict["start"]
# param_end_index = offset_dict["end"]
# param_shard_start_index = max(local_shard_start_index,
# param_start_index)
# param_shard_end_index = min(local_shard_end_index,
# param_end_index)
                orig_start_index = offset_dict["start"]
                orig_end_index = offset_dict["end"]
                shard_start_index = max(
                    0,
                    orig_start_index - local_shard_start_index)
                shard_end_index = min(
                    local_shard_end_index,
                    orig_end_index - local_shard_start_index)
# if param_shard_end_index > param_shard_start_index:
# # Indexes are relative to local shard start index.
# # local_shard_info["param_index_map"][param] = {
# # "param" : (
# # param_shard_start_index,
# # param_shard_end_index,
# # ),
# # "shard" : (
# # param_shard_start_index - local_shard_start_index,
# # param_shard_end_index - local_shard_start_index,
# # ),
# # }
# local_shard_info["param_slice_index_map"][param] = {
# "param_start" :
# param_shard_start_index,
# "shard_start" :
# param_shard_start_index - local_shard_start_index,
# "size":
# param_shard_end_index - param_shard_start_index,
# }
                if shard_end_index > shard_start_index:
                    local_shard_info["param_slice_index_map"][param] = {
                        "orig_start" : orig_start_index,
                        "shard_start" : shard_start_index,
                        "size" : shard_end_index - shard_start_index,
                    }
# pax(0, {
# "local index" : "%d, %d" % (
# local_shard_start_index,
# local_shard_end_index,
# ),
# "param index" : "%s, %d" % (
# param_start_index,
# param_end_index,
# ),
# "param" : tp(param),
# "shard_param_index_map" : shard_param_index_map,
# "local_shard_info" : local_shard_info,
# })
# pax(2, {
# "data_parallel_rank" : self.data_parallel_rank,
# "local_shard_info" : local_shard_info,
# "param_index_map " : [
# (str(p.shape), i)
# for p, i in local_shard_info["param_index_map"].items()
# ],
# })
# Allocate shards.
# (Non-fp32 shards are for convenience; e.g., intermediaries
# between model params and main fp32 shard. Necessary???)
# main_param_shards = {
# ty : allocate_shard(local_shard_size, ty)
# for ty in model_main_dtypes}
            main_param_shards = {}
            for dtype in model_main_dtypes:
                main_param = allocate_shard(local_shard_size, dtype)
                main_param.grad = allocate_shard(local_shard_size, dtype)
                # pax(0, {"main_param": main_param})
                main_param_shards[dtype] = main_param
            # self.main_param_shard_groups.append(main_param_shards)
            local_shard_info["data"] = main_param_shards

            # Update optimizer group.
            self.optimizer.param_groups[group_index]["params"] = \
                [ main_param_shards[torch.float] ]

        # World shard infos.
        self.world_shard_infos = []
        for rank in range(self.data_parallel_world_size):

            # pax(0, {
            #     "param_groups" : self.optimizer.param_groups,
            #     "params" : self.optimizer.param_groups[group_index]["params"],
            # })

            # Per-rank info.
            rank_info = []
            self.world_shard_infos.append(rank_info)

            for model_index, model in enumerate(self.models):
# Add world start/end indexes, for reduce/gather steps.
        offset = 0
        for r in self.world_shard_infos:
            r["start_index"] = offset
            offset += sum(g["size"] for g in r["groups"])
            r["end_index"] = offset
# Per-virtual-model info.
# model_info = {}
# rank_info.append(model_info)
                ddp_gbuf_shards = self.get_ddp_gbuf_shards(model)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
# >>>
# pax(0, {
# "world_shard_infos" : self.world_shard_infos,
# **{
# "world_shard_infos / %d" % i : r
# for i, r in enumerate(self.world_shard_infos)
# },
# })
        pax(0, {
            "world_shard_infos" : self.world_shard_infos,
        })
# <<<
# def get_loss_scale(self):
...
...
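The commit stops at building the per-rank shard bookkeeping; the communication the comments allude to ("for reduce/gather steps") is not implemented here. Below is a sketch of what a grad-buffer-centric reduce-scatter over equal, padded shards could look like, using only standard torch.distributed collectives; the function names, arguments, and padding assumption are illustrative, not the committed design.

    import torch
    import torch.distributed as dist

    # Illustrative only: reduce-scatter a flat grad buffer so each data-parallel
    # rank ends up owning the summed gradients for its shard, then (after the
    # optimizer step) all-gather the updated values back into a flat buffer.
    def reduce_scatter_grad_buffer(grad_buffer, group):
        world_size = dist.get_world_size(group=group)
        shard_size = grad_buffer.numel() // world_size  # assumes buffer padded to a multiple
        shards = list(grad_buffer.view(world_size, shard_size).unbind(0))
        local_shard = torch.empty_like(shards[0])
        dist.reduce_scatter(local_shard, shards, op=dist.ReduceOp.SUM, group=group)
        return local_shard  # this rank's reduced gradient shard

    def all_gather_params(param_buffer, local_shard, group):
        world_size = dist.get_world_size(group=group)
        shards = list(param_buffer.view(world_size, local_shard.numel()).unbind(0))
        dist.all_gather(shards, local_shard, group=group)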
megatron/training.py
...
...
@@ -365,8 +365,12 @@ def setup_model_and_optimizer(model_provider_func,
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))

    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
    # >>>
    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
    #                                    scale_lr_cond, lr_mult)
    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
    # <<<

    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
...
...
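The change above passes the DDP-wrapped model list, rather than the unwrapped modules, into get_megatron_optimizer: the grad-buffer-centric optimizer needs the contiguous per-dtype grad buffers that live on Megatron's local DistributedDataParallel wrapper. A rough sketch of the access pattern this enables, relying on the _grad_buffers and _grad_buffer_param_index_map attributes referenced in the optimizer.py hunks; the iterator itself is illustrative, not part of the commit.

    # Illustrative only: walk each (possibly virtual-pipeline) model chunk and
    # its per-dtype contiguous grad buffers, as the new optimizer expects to.
    def iter_grad_buffers(models):
        for model_index, model in enumerate(models):
            for dtype, grad_buffer in model._grad_buffers.items():
                index_map = model._grad_buffer_param_index_map[dtype]
                yield model_index, dtype, grad_buffer, index_map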