OpenDAS / Megatron-LM · Commits

Commit f48e1f29
Authored Feb 11, 2022 by Lawrence McAfee
Parent: 49cca4d9

    studied float16 optimizer; more updates
Changes: 1 changed file, with 192 additions and 59 deletions.

megatron/optimizer/optimizer.py  (+192, -59)
@@ -643,102 +643,235 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
Deleted lines (the previous Float16DistributedOptimizer draft, which subclassed Float16OptimizerWithFloat16Params):

# >>>
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):

    import math

    def __init__(self, *args):
        # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
        super().__init__(*args)
        self.initialized = False
        # >>>
        self.initialize()
        # <<<

    def initialize(self):

        from megatron import get_args
        # from megatron import get_timers
        # from megatron.model import DistributedDataParallel as LocalDDP
        # from megatron.model import Float16Module
        # from megatron.utils import unwrap_model
        # >>>
        import math
        from lutil import pax, tp
        # <<<

        if self.initialized:
            raise Exception("initialization worked.")
            return
        self.initialized = True

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        total_param_size = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"]
        )

        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
        shard_start_index = data_parallel_rank * shard_size
        shard_end_index = min(total_param_size, shard_start_index + shard_size)
        self.shard_size = shard_end_index - shard_start_index

        # allocate_shard = lambda dtype : torch.empty(
        #     [self.shard_size],
        #     dtype = dtype,
        #     device = torch.cuda.current_device())
        allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)

        self.main_param_shard = allocate_shard(torch.float)
        self.main_grad_shard = allocate_shard(torch.float)
        self.adam_m_shard = allocate_shard(torch.float)
        self.adam_v_shard = allocate_shard(torch.float)

Added lines (the reworked Float16DistributedOptimizer, subclassing MegatronOptimizer directly):

# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
class Float16DistributedOptimizer(MegatronOptimizer):

    # >>>
    @classmethod
    def test_reduce_scatter(cls):

        torch.manual_seed(mpu.get_data_parallel_rank())

        size = (20,)
        dtype = torch.float
        device = torch.cuda.current_device()

        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        input_list = [
            # torch.randn(size, dtype = dtype, device = device)
            5 * torch.randint(low = 1, high = 3, size = size,
                              dtype = dtype, device = device)
            for _ in range(data_parallel_world_size)
        ]
        output = torch.empty(size, dtype = dtype, device = device)

        torch.distributed.reduce_scatter(
            output,
            input_list,
            group = data_parallel_group,
        )

        if torch.distributed.get_rank() == 0:
            print(output)

        pax(0, {
            "data_parallel_world_size" : data_parallel_world_size,
            "data_parallel_group" : data_parallel_group,
            "input_list" : input_list,
            "output" : tp(output),
        })
    # <<<
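For reference on the collective being exercised here: torch.distributed.reduce_scatter(output, input_list, group=...) leaves, on rank r, the elementwise sum over all ranks of each rank's r-th input tensor. The following standalone sketch (not part of the commit) simulates that result with plain tensor ops, so it runs in a single process without initializing a process group:

import torch

# Simulate reduce_scatter across `world_size` ranks: rank r ends up with the
# elementwise sum, over all source ranks, of each source's r-th input slot.
world_size = 4
size = (20,)

torch.manual_seed(0)
# per_rank_inputs[src] plays the role of `input_list` as built on rank `src`.
per_rank_inputs = [
    [5 * torch.randint(low=1, high=3, size=size, dtype=torch.float)
     for _ in range(world_size)]
    for _ in range(world_size)]

# Expected `output` on rank r.
expected_outputs = [
    torch.stack([per_rank_inputs[src][r] for src in range(world_size)]).sum(dim=0)
    for r in range(world_size)]

print(expected_outputs[0])  # what rank 0's output buffer would hold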
    # def __init__(self, *_args):
    #     super().__init__(*_args)
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp)

        # >>>
        # self.test_reduce_scatter()
        # <<<

        from megatron import get_args
        # >>>
        args = get_args()
        # <<<

        # Data parallel info.
        self.data_parallel_group = mpu.get_data_parallel_group()
        self.data_parallel_rank = mpu.get_data_parallel_rank()
        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
        # Total trainable param count.
        # self.total_param_size = sum(
        #     p.numel()
        #     for g in self.param_groups
        #     for p in g["params"]
        #     # if p.requires_grad ???
        # )

        # Model params: group sizes, group offset maps.
        # self.model_params = []
        # self.model_param_group_sizes = []
        # self.model_param_group_offset_maps = []
        self.model_param_groups = []
        for param_group in self.optimizer.param_groups:
            param_group_offset = 0
            param_group_offset_map = {}
            for param in param_group['params']:
                if not param.requires_grad:
                    continue
                # self.model_params.append(param)
                param_group_offset_map[param] = {
                    "start" : param_group_offset,
                    "end" : param_group_offset + param.numel(),
                }
                param_group_offset += param.numel()
            # self.model_param_group_sizes.append(param_group_offset)
            # self.model_param_group_offset_maps.append(param_group_offset_map)
            self.model_param_groups.append({
                "size" : param_group_offset,
                "offset_map" : param_group_offset_map,
            })

        # pax(0, {
        #     "model_params" : model_params,
        #     "model_param_group_sizes" : model_param_group_sizes,
        #     "model_param_group_offset_maps" : model_param_group_offset_maps,
        # })
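The loop above flattens each param group into a single index space. A standalone sketch of that bookkeeping (hypothetical parameters, not from the commit) shows what the per-group "size" and "offset_map" entries end up containing:

import torch

# Each trainable parameter gets a contiguous [start, end) range in a flat
# index space that covers its whole param group.
params = [torch.nn.Parameter(torch.zeros(4, 3)),   # 12 elements
          torch.nn.Parameter(torch.zeros(5)),      # 5 elements
          torch.nn.Parameter(torch.zeros(2, 2))]   # 4 elements

offset = 0
offset_map = {}
for p in params:
    if not p.requires_grad:
        continue
    offset_map[p] = {"start": offset, "end": offset + p.numel()}
    offset += p.numel()

group_size = offset
assert group_size == 21
assert offset_map[params[1]] == {"start": 12, "end": 17}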
        # Shard allocator.
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device())
        # allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)

        # Collect DP world shard infos, per group.
        model_main_dtypes = set([args.params_dtype, torch.float])
        self.world_shard_info_groups = [] # world_group_shard_infos ?
        self.main_param_shard_groups = []
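model_main_dtypes pairs args.params_dtype (typically fp16 or bf16) with torch.float, which reflects the usual mixed-precision split: the model holds half-precision parameters and gradients, while the optimizer steps full-precision "main" copies. A minimal standalone sketch of that split (not from this commit):

import torch

# fp16 model parameter with an fp32 "main" copy; the update happens in fp32
# and is cast back down into the model parameter afterwards.
model_param = torch.randn(8, dtype=torch.float16, requires_grad=True)
main_param = model_param.detach().clone().float()       # fp32 master copy

model_param.grad = torch.randn(8, dtype=torch.float16)  # pretend backward() ran
main_grad = model_param.grad.float()                     # fp32 copy of the grad

main_param -= 1e-3 * main_grad                           # optimizer step in fp32
model_param.data.copy_(main_param)                       # cast back to fp16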
        for model_param_group_size in model_param_group_sizes:

            max_world_shard_size = int(math.ceil(model_param_group_size /
                                                 self.data_parallel_world_size))

            # Group shard infos.
            shard_infos = []
            for r in range(self.data_parallel_world_size):
                shard_start_index = r * max_shard_size
                shard_end_index = min(self.total_param_size,
                                      shard_start_index + max_shard_size)
                shard_infos.append({
                    "start" : shard_start_index,
                    "end" : shard_end_index,
                    "size" : shard_end_index - shard_start_index,
                })
            self.world_shard_info_groups.append(shard_infos)

            # Allocate shards.
            local_shard_size = \
                self.world_shard_infos[self.data_parallel_rank]["size"]
            # # self.main_param_shard = allocate_shard(torch.float)
            # # self.main_grad_shard = allocate_shard(torch.float)
            # self.param_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
            # self.grad_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
            # self.adam_m_shard = allocate_shard(torch.float)
            # self.adam_v_shard = allocate_shard(torch.float)
            self.main_param_shard_groups.append({
                ty : allocate_shard(ty)
                for ty in model_main_dtypes
            })

        # >>>
        # pax(0, {
        #     "total_param_size" : self.total_param_size,
        #     "max_shard_size" : max_shard_size,
        #     "shard_infos" : self.shard_infos,
        #     "shard_size" : shard_size,
        #     "param_shard_map" : self.param_shard_map,
        # })
        # <<<
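The shard arithmetic above ceil-divides each group across the data-parallel world, so every rank except possibly the last gets an equal slice, and the last slice may be short. A standalone sketch of that calculation (hypothetical helper name, not from the commit):

import math

# Per-rank shard ranges for one param group of `group_size` elements.
def build_shard_infos(group_size, world_size):
    max_shard_size = int(math.ceil(group_size / world_size))
    infos = []
    for r in range(world_size):
        start = r * max_shard_size
        end = min(group_size, start + max_shard_size)
        infos.append({"start": start, "end": end, "size": max(0, end - start)})
    return infos

print(build_shard_infos(group_size=21, world_size=4))
# [{'start': 0, 'end': 6, 'size': 6}, {'start': 6, 'end': 12, 'size': 6},
#  {'start': 12, 'end': 18, 'size': 6}, {'start': 18, 'end': 21, 'size': 3}]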
    def get_loss_scale(self):
        raise Exception("hi.")

    def load_state_dict(self):
        raise Exception("hi.")

    def reload_model_params(self):
        raise Exception("hi.")

    def state_dict(self):
        raise Exception("hi.")

    def zero_grad(self):
        raise Exception("hi.")
    def reduce_gradients(self, model):

        # >>>
        args = get_args()
        # timers = get_timers()
        # <<<

        # >>>
        # [ already checked in arguments.py ]
        assert args.use_contiguous_buffers_in_local_ddp
        # <<<

        # grad_buffers = [ m._grad_buffers for m in model ]
        for virtual_model in model:
            grad_buffer_map = virtual_model._grad_buffers

            # >>>
            assert len(grad_buffer_map) == 1, \
                "multiple param types not currently supported."
            assert args.params_dtype in grad_buffer_map
            assert self.total_param_size == grad_buffer_map[args.params_dtype].numel
            # <<<

            # pax(0, {
            #     "total_param_size" : self.total_param_size,
            #     "grad_buffer" : tp(grad_buffer_map[args.params_dtype]),
            # })

            for dtype, grad_buffer in grad_buffer_map.items():

                dp_grad_buffers = [
                    grad_buffer.get(
                        torch.Size((self.shard_infos[i]["size"],)),
                        self.shard_infos[i]["start"])
                    for i in range(self.data_parallel_world_size)]
                grad_shard = self.grad_shard_map[dtype]

                pax(0, {"dp_grad_buffers" : dp_grad_buffers})

                torch.distributed.reduce_scatter(
                    grad_shard,
                    dp_grad_buffers,
                    group = self.data_parallel_group,
                )

                # >>>
                pax(0, {
                    "virtual_model" : virtual_model,
                    "grad_buffer_map" : grad_buffer_map,
                    "dtype" : dtype,
                    "grad_buffer / len" : grad_buffer.numel,
                    "grad_shard" : tp(grad_shard),
                    "grad_buffer / data" : tp(grad_buffer.data),
                    "dp_grad_buffers" : dp_grad_buffers,
                    # "optimizer" : self.optimizer,
                })
                # <<<

        # >>>
        from lutil import pax, tp
        pax(0, {
            "model" : model,
            "grad_buffers" : grad_buffers,
...
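The intent of reduce_gradients, as sketched by the asserts and the reduce_scatter call, is to view each contiguous grad buffer as one slice per data-parallel rank and reduce-scatter it, so each rank is left holding fully reduced gradients for only its own shard. Below is a standalone, hedged sketch of that pattern (the helper name and the padding strategy are assumptions, not code from this commit); the collective only runs if a process group has been initialized:

import math
import torch
import torch.distributed as dist

# Reduce-scatter a flat, contiguous gradient buffer across `world_size` ranks.
# Returns the local shard; on rank r it holds the reduced values of slice r.
def reduce_scatter_grad_buffer(grad_buffer_data, world_size, group=None):
    numel = grad_buffer_data.numel()
    shard_size = int(math.ceil(numel / world_size))
    # Pad so the buffer splits into equal-sized shards (reduce_scatter
    # requires every input tensor to match the output size).
    padded = torch.zeros(shard_size * world_size,
                         dtype=grad_buffer_data.dtype,
                         device=grad_buffer_data.device)
    padded[:numel].copy_(grad_buffer_data)
    # One equal-sized view per rank into the padded buffer.
    shard_views = [padded.narrow(0, r * shard_size, shard_size)
                   for r in range(world_size)]
    local_shard = torch.empty_like(shard_views[0])
    if dist.is_available() and dist.is_initialized():
        dist.reduce_scatter(local_shard, shard_views, group=group)
    return local_shard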