OpenDAS / Megatron-LM / Commits / 11581195

Commit 11581195, authored Mar 23, 2022 by Lawrence McAfee

    removed mem savings factor. +cleanup.

Parent: 41038d54
Showing 4 changed files, with 15 additions and 241 deletions (+15 -241).
megatron/arguments.py                     +0    -6
megatron/optimizer/clip_grads.py          +1   -35
megatron/optimizer/distrib_optimizer.py   +9  -169
megatron/optimizer/optimizer.py           +5   -31
megatron/arguments.py

@@ -733,12 +733,6 @@ def _add_distributed_args(parser):
                        'affects the encoder embedding.)')
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
-    group.add_argument('--distrib-opt-comm-mem-savings', default=0., type=float,
-                       help='Trade-off memory savings & iteration time, for '
-                       'disributed optimizer\'s communication operations (i.e., '
-                       '(reduce/gather). This value ranges from 0.0 (default, '
-                       'no memory savings) to 1.0 (max memory savings, at the '
-                       'expense of iteration time).')
     return parser
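Note: the deleted flag let users trade communication memory for iteration time by chunking the distributed optimizer's reduce-scatter/all-gather calls. Below is a minimal sketch of the interpolation it controlled, based on the (also removed) get_model_grad_buffer_dp_views_chunked helper shown further down in this commit; the standalone function name and the example numbers are illustrative only.

# Sketch of the chunk-size interpolation behind '--distrib-opt-comm-mem-savings'.
# factor = 0.0 -> one chunk per full buffer view (fastest, most temporary memory);
# factor = 1.0 -> minimum-size chunks (least memory, more communication calls).
def chunk_numel(view_numel, mem_savings_factor, chunk_numel_min=131072):
    return int(mem_savings_factor * chunk_numel_min
               + (1.0 - mem_savings_factor) * view_numel)

if __name__ == "__main__":
    print(chunk_numel(1_000_000, 0.0))   # 1000000 (no savings)
    print(chunk_numel(1_000_000, 0.5))   # 565536
    print(chunk_numel(1_000_000, 1.0))   # 131072 (max savings)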
megatron/optimizer/clip_grads.py

@@ -51,31 +51,8 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
 
-    # >>>
-    # # Filter parameters based on:
-    # #   - grad should not be none
-    # #   - parameter should not be shared
-    # #   - should not be a replica due to tensor model parallelism
-    # grads = []
-    # grads_for_norm = []
-    # for param in parameters:
-    #     grad_not_none = param.grad is not None
-    #     is_not_shared = param_is_not_shared(param)
-    #     is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
-    #     if grad_not_none:
-    #         grad = param.grad.detach()
-    #     if grad_not_none:
-    #         # Make sure the grads are in fp32
-    #         assert param.grad.type() == 'torch.cuda.FloatTensor'
-    #         grads.append(grad)
-    #     if grad_not_none and is_not_shared and is_not_tp_duplicate:
-    #         grads_for_norm.append(grad)
-    # <<<
-    # >>>
     # Grads.
-    grads = [p.grad for p in parameters if p is not None]
-    # <<<
+    grads = [p.grad.detach() for p in parameters if p.grad is not None]
 
     # Norm parameters.
     max_norm = float(max_norm)

@@ -119,17 +96,6 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
                                      group=model_parallel_group)
         total_norm = total_norm.item() ** (1.0 / norm_type)
 
-    # >>>
-    # from megatron import get_args
-    # from lutil import pax
-    # args = get_args()
-    # pax(0, {
-    #     "use distrib opt" : args.use_distributed_optimizer,
-    #     "norm_type" : norm_type,
-    #     "total_norm" : total_norm,
-    # })
-    # <<<
     # Scale.
     clip_coeff = max_norm / (total_norm + 1.0e-6)
     if clip_coeff < 1.0:
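Note: both hunks sit in the L2 grad-norm clipping path: gradients are collected, the norm is reduced over the model-parallel group, and a clip coefficient max_norm / (total_norm + 1e-6) is applied only when it is below 1. The following is a single-process sketch of that arithmetic, with toy tensors, no model-parallel all-reduce, and plain in-place scaling instead of whatever fused kernels the real clip_grad_norm_fp32 uses.

import torch

def clip_grads_l2_(grads, max_norm, norm_type=2.0):
    # Accumulate sum of |g|^norm_type, then take the norm_type-th root,
    # mirroring 'total_norm.item() ** (1.0 / norm_type)' above.
    total_norm = sum(g.norm(norm_type) ** norm_type for g in grads)
    total_norm = total_norm.item() ** (1.0 / norm_type)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm

grads = [torch.full((4,), 3.0), torch.full((9,), 4.0)]
print(clip_grads_l2_(grads, max_norm=1.0))   # ~13.42; grads are scaled in place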
megatron/optimizer/distrib_optimizer.py

@@ -22,17 +22,11 @@ import torch
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
-# >>>
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-# <<<
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# >>>
-from .optimizer import get_clippy
-from lutil import pax, tp
-# <<<
 
 
 class Shard:
     def __init__(self, start, end):
@@ -196,12 +190,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             # Update group's param.
             group_shard["orig_group"]["params"] = [main_param]
 
-    # >>>
     @classmethod
     def get_main_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):
         grad_views = []
-        # grad_views_SKIPPED = []
         for group_index, opt_group_shard in enumerate(opt_group_shards):
             opt_grad = optimizer.param_groups[group_index]["params"][0].grad
             for param, shard in opt_group_shard["param_map"].items():

@@ -211,30 +203,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     grad_view = opt_grad[shard.start:shard.end]
                     grad_views.append(grad_view)
-                # else:
-                #     grad_views_SKIPPED.append(opt_grad[shard.start:shard.end])
-        # >>>
-        # my_rank = torch.distributed.get_rank()
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == my_rank:
-        #         print("r %d, grad views %s." % (
-        #             my_rank,
-        #             ", ".join(str(tuple(g.shape)) for g in grad_views),
-        #         ))
-        #     torch.distributed.barrier()
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == my_rank:
-        #         print("r %d, SKIPPED %s." % (
-        #             my_rank,
-        #             ", ".join(str(tuple(g.shape)) for g in grad_views_SKIPPED),
-        #         ))
-        #     torch.distributed.barrier()
-        # exit(0)
-        # <<<
         return grad_views
-    # <<<
 
     def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                  params_have_main_grad, use_contiguous_buffers_in_local_ddp,
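Note: get_main_grad_views_for_grad_norm slices each param group's flat main grad into per-parameter views using the (start, end) offsets kept in Shard objects, so the grad norm is computed over exactly the elements this rank owns. Below is a toy, single-tensor illustration of that slicing; the shapes and parameter names are invented, and the two-field Shard class simply mirrors the one defined in this file.

import torch

class Shard:
    # Same two-field bookkeeping as the Shard class in distrib_optimizer.py.
    def __init__(self, start, end):
        self.start = start
        self.end = end

flat_grad = torch.arange(10.0)                    # stand-in for one group's main grad
param_map = {"weight": Shard(0, 6), "bias": Shard(6, 10)}

grad_views = [flat_grad[s.start:s.end] for s in param_map.values()]
for name, view in zip(param_map, grad_views):
    print(name, tuple(view.shape))                # weight (6,) / bias (4,)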
@@ -274,16 +243,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Initialize main params.
         self._copy_model_params_to_main_params()
 
-        # >>> numel/nelem per rank >>>
-        # for r in range(torch.distributed.get_world_size()):
-        #     if r == torch.distributed.get_rank():
-        #         for m in self.models:
-        #             for b in m._grad_buffers.values():
-        #                 print("r %d, %d." % (r, b.data.nelement()))
-        #     torch.distributed.barrier()
-        # exit(0)
-        # <<<
         # Params for grad norm.
         self.main_grad_views_for_grad_norm = self.get_main_grad_views_for_grad_norm(
             self.opt_group_shards,
@@ -293,47 +252,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_parallel_group(self):
         return None
 
-    # >>>
-    # @staticmethod
-    # def has_nan_debug(tensors):
-    #     if isinstance(tensors, torch.Tensor):
-    #         tensors = [ tensors ]
-    #     assert isinstance(tensors, list)
-    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
-    #     has_nan = any(has_nans)
-    #     return has_nan
-    # def get_local_model_param_views(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     model_param_views = []
-    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-    #         for param, opt_shard in opt_group_shard["param_map"].items():
-    #             model_index, dtype = self.param_gbuf_map[param]
-    #             gbuf_shard_map = \
-    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-    #             model_param_shard = gbuf_shard_map["param"]
-    #             model_param_views.append(
-    #                 param.view(-1)[model_param_shard.start:model_param_shard.end])
-    #     return model_param_views
-    # def get_local_model_grad_views(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     model_grad_views = []
-    #     for group_index, opt_group_shard in enumerate(self.opt_group_shards):
-    #         for param, opt_shard in opt_group_shard["param_map"].items():
-    #             model_index, dtype = self.param_gbuf_map[param]
-    #             gbuf = self.models[model_index]._grad_buffers[dtype].data
-    #             gbuf_shard_map = \
-    #                 self.model_gbuf_shards[model_index][dtype]["param_map"][param]
-    #             gbuf_world_shard = gbuf_shard_map["gbuf_world"]
-    #             model_grad_views.append(
-    #                 gbuf[gbuf_world_shard.start:gbuf_world_shard.end])
-    #     return model_grad_views
-    # def get_world_model_params(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     return [ p for m in self.models for p in m.parameters() ]
-    # def get_world_model_grads(self):
-    #     '''** FOR DEBUGGING. **'''
-    #     return [ p.main_grad for p in self.get_world_model_params() ]
-    # <<<
     def get_main_params(self):
         return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -344,10 +262,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_main_grad(self, group_index):
         return self.get_main_param(group_index).grad
 
-    # >>>
-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         return self.main_grad_views_for_grad_norm
-    # <<<
 
     def state_dict(self):
         state_dict = {}
@@ -386,6 +304,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
 
     def zero_grad(self, set_to_none=True):
         # Collect model params.

@@ -397,6 +316,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Distributed optimizer requires contiguous buffer; don't set to None.
         _zero_grad_group_helper(model_params, set_to_none=False)
 
     def get_model_grad_buffer_dp_views(self):
         data_parallel_world_size = mpu.get_data_parallel_world_size()
@@ -410,53 +330,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             shard_size = int(gbuf.numel_padded / data_parallel_world_size)
             gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                           for r in range(data_parallel_world_size)]
-            # gbuf_view_items.append((model_index, dtype, gbuf_views))
             gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))
 
         return gbuf_view_items
 
-    # >>>
-    # def get_model_grad_buffer_dp_views_SINGLE(self):
-    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
-    #     # Grad buffer views.
-    #     gbuf_items = []
-    #     for model_index, model in enumerate(self.models):
-    #         for dtype, gbuf in model._grad_buffers.items():
-    #             gbuf_items.append((model_index, dtype, gbuf.data))
-    #     return gbuf_items
-    # <<<
-    # >>>
-    # def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
-    #     # Iterate grad buffers & chunk.
-    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
-    #     chunk_view_items = []
-    #     for model_index, dtype, gbuf_views in gbuf_view_items:
-    #         # ** Sanity check. ** (should be unnecessary; see comment above)
-    #         view_numel = gbuf_views[0].nelement()
-    #         for view in gbuf_views:
-    #             assert view.nelement() == view_numel
-    #         # Compute chunk size (via savings factor).
-    #         chunk_numel_min = 131072
-    #         chunk_numel_max = view_numel
-    #         chunk_numel = int(
-    #             mem_savings_factor * chunk_numel_min
-    #             + (1 - mem_savings_factor) * chunk_numel_max
-    #         )
-    #         # Chunk views.
-    #         for start_index in range(0, view_numel, chunk_numel):
-    #             end_index = min(view_numel, start_index + chunk_numel)
-    #             chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
-    #             chunk_view_items.append((model_index, dtype, chunk_views))
-    #     return chunk_view_items
-    # <<<
     def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
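Note: get_model_grad_buffer_dp_views carves each grad buffer, already padded to a multiple of the data-parallel world size, into one equal-sized view per rank; the removed _chunked variant further split those views to cap temporary communication memory. Below is a single-process sketch of the view arithmetic only, with toy sizes and no distributed calls.

import torch

data_parallel_world_size = 4
numel_padded = 12                      # buffer already padded to a multiple of 4
gbuf = torch.zeros(numel_padded)

shard_size = numel_padded // data_parallel_world_size
gbuf_views = [gbuf[r * shard_size:(r + 1) * shard_size]
              for r in range(data_parallel_world_size)]

# Every view aliases the same storage and has the same length, which is what
# lets the reduce-scatter / all-gather below operate on equal-sized shards.
assert all(v.numel() == shard_size for v in gbuf_views)
assert gbuf_views[2].data_ptr() == gbuf.data_ptr() + 2 * shard_size * gbuf.element_size()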
@@ -474,44 +350,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
-        mem_savings_factor = args.distrib_opt_comm_mem_savings
 
         # Scale grad buffers by '1 / data_parallel_world_size'.
         for model in self.models:
             for dtype, gbuf in model._grad_buffers.items():
                 gbuf.data /= data_parallel_world_size
 
-        # Reduce scatter all grads.
-        # >>>
-        # gbuf_view_items = \
-        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        # for model_index, dtype, gbuf_views in gbuf_view_items:
-        #     torch.distributed.reduce_scatter(
-        #         gbuf_views[data_parallel_rank],
-        #         gbuf_views,
-        #         group = data_parallel_group,
-        #     )
-        # +++
+        # Reduce-scatter all grads.
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
         for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
-            # >>>
-            # pax(0, {
-            #     "gbuf_view" : gbuf_views[data_parallel_rank].shape,
-            #     "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
-            # })
-            # <<<
             torch.distributed._reduce_scatter_base(
                 gbuf_views[data_parallel_rank],
-                # gbuf_view_items_SINGLE[index][2],
                 gbuf,
                 group = data_parallel_group,
             )
-            # torch.distributed.reduce_scatter(
-            #     gbuf_views[data_parallel_rank],
-            #     gbuf_views,
-            #     group = data_parallel_group,
-            # )
-        # <<<
 
         timers('backward-params-all-reduce').stop()
 
     def gather_model_params(self, args, timers):
@@ -520,32 +373,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
-        mem_savings_factor = args.distrib_opt_comm_mem_savings
 
         # All-gather updated main params.
         # - All grad buffer views are guaranteed to have the same num elements
         #   across all data parallel ranks, with grad buffer padding that is done
         #   in distributed.py. Thus, all sub-views will have consistent start/end
         #   indexes across data parallel ranks.
-        # >>>
-        # gbuf_view_items = \
-        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        # for model_index, dtype, gbuf_views in gbuf_view_items:
-        #     torch.distributed.all_gather(
-        #         gbuf_views,
-        #         gbuf_views[data_parallel_rank],
-        #         group = data_parallel_group,
-        #     )
-        # +++
         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        # gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
         for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
             torch.distributed._all_gather_base(
-                # gbuf_view_items_SINGLE[index][2],
                 gbuf,
                 gbuf_views[data_parallel_rank],
                 group = data_parallel_group,
             )
-        # <<<
 
         # Each model param now contains its updated values in its
         # '.main_grad' field.
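Note: reduce_model_grads pre-scales each grad buffer by 1/data_parallel_world_size and reduce-scatters it so every rank owns one shard; gather_model_params later all-gathers the updated shards back into the full buffer. The sketch below emulates that round trip in a single process, treating rows of a tensor as the "ranks"; the real code instead calls torch.distributed._reduce_scatter_base and _all_gather_base over the data-parallel group.

import torch

world_size, shard_size = 4, 3
# One padded grad buffer per simulated data-parallel rank (rows = ranks).
gbufs = torch.arange(float(world_size * world_size * shard_size)).reshape(
    world_size, world_size * shard_size)
gbufs /= world_size                              # pre-scaling step

# Reduce-scatter: rank r ends up owning the element-wise sum of every rank's
# r-th view (an average here, because of the pre-scaling).
owned = [gbufs.sum(dim=0)[r * shard_size:(r + 1) * shard_size]
         for r in range(world_size)]

# ... optimizer.step() would update each owned shard at this point ...

# All-gather: every rank reassembles the full (updated) buffer.
full = torch.cat(owned)
assert full.numel() == world_size * shard_size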
megatron/optimizer/optimizer.py

@@ -27,23 +27,11 @@ from megatron import mpu
 from megatron import print_rank_0
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
+from megatron.utils import unwrap_model
+from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-# >>>
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
-from megatron.utils import unwrap_model
-from lutil import pax
-from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
-
-get_clippy = lambda params : [
-    "%d, %d, %d ... %s" % (
-        p.grad is not None,
-        param_is_not_shared(p),
-        param_is_not_tensor_parallel_duplicate(p),
-        str(tuple(p.shape)),
-    )
-    for p in params
-]
-# <<<
 
 
 def _zero_grad_group_helper(group, set_to_none):
@@ -112,12 +100,9 @@ class MegatronOptimizer(ABC):
                 params.append(param)
         return params
 
-    # >>>
     @abstractmethod
-    # def get_grads_for_norm(self):
-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         pass
-    # <<<
 
     def get_model_parallel_group(self):
         '''Default returned here, but the distributed optimizer overrides this.'''

@@ -126,7 +111,7 @@ class MegatronOptimizer(ABC):
     def clip_grad_norm(self, clip_grad):
         params = self.get_parameters()
-        grads_for_norm = self._get_main_grads_for_grad_norm()
+        grads_for_norm = self.get_main_grads_for_grad_norm()
         return clip_grad_norm_fp32(
             params, grads_for_norm, clip_grad,
             model_parallel_group=self.get_model_parallel_group())
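Note: across these hunks the hook is renamed from _get_main_grads_for_grad_norm to get_main_grads_for_grad_norm while keeping the same pattern: the base class drives clipping and each optimizer subclass supplies its own grad list. Below is a stripped-down sketch of that dispatch, with stub classes and an inline norm/scale in place of clip_grad_norm_fp32 and the real Megatron optimizers.

from abc import ABC, abstractmethod
import torch

class OptimizerBase(ABC):
    """Toy stand-in for MegatronOptimizer's grad-norm hook."""

    @abstractmethod
    def get_main_grads_for_grad_norm(self):
        pass

    def clip_grad_norm(self, clip_grad):
        grads = self.get_main_grads_for_grad_norm()
        total_norm = sum(float((g ** 2).sum()) for g in grads) ** 0.5
        # Inline stand-in for clip_grad_norm_fp32(params, grads, clip_grad, ...).
        coeff = clip_grad / (total_norm + 1.0e-6)
        if coeff < 1.0:
            for g in grads:
                g.mul_(coeff)
        return total_norm

class ToyDistributedOptimizer(OptimizerBase):
    def __init__(self, grad_views):
        self.main_grad_views_for_grad_norm = grad_views

    def get_main_grads_for_grad_norm(self):
        # Mirrors the distributed optimizer: return the precomputed shard views.
        return self.main_grad_views_for_grad_norm

opt = ToyDistributedOptimizer([torch.ones(4), torch.ones(9) * 2.0])
print(opt.clip_grad_norm(clip_grad=1.0))   # ~6.32; views scaled in place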
@@ -544,17 +529,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         # recast preexisting per-param state tensors
         self.optimizer.load_state_dict(self.optimizer.state_dict())
 
-        # >>>
-        # model_params = [ p for m in self.models for p in m.parameters() ]
-        # optim_params = self.get_parameters()
-        # model_params.sort(key = lambda p : p.nelement(), reverse = True)
-        # optim_params.sort(key = lambda p : p.nelement(), reverse = True)
-        # # assert len(model_params) == len(optim_params
-        # pax(7, {
-        #     "model_params" : get_clippy(model_params),
-        #     "optim_params" : get_clippy(optim_params),
-        # })
-        # <<<
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,

@@ -570,7 +544,7 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
             _zero_grad_group_helper(group, set_to_none)
 
-    def _get_main_grads_for_grad_norm(self):
+    def get_main_grads_for_grad_norm(self):
         # Filter parameters based on:
         #   - grad should not be none
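Note: the Float16 optimizer's implementation (context above) filters per-parameter grads before computing the norm: the grad must exist, the parameter must not be shared, and it must not be a tensor-model-parallel duplicate. Below is a hedged sketch of that filter with the two predicates stubbed out; Megatron's real helpers are param_is_not_shared and param_is_not_tensor_parallel_duplicate, imported at the top of the file.

import torch

def filter_grads_for_norm(params, is_not_shared, is_not_tp_duplicate):
    # Keep a grad only if it exists, its parameter is not shared, and it is
    # not a tensor-parallel replica (per the comments in the hunk above).
    grads_for_norm = []
    for param in params:
        if param.grad is not None and is_not_shared(param) and is_not_tp_duplicate(param):
            grads_for_norm.append(param.grad.detach())
    return grads_for_norm

w = torch.nn.Parameter(torch.ones(3)); w.grad = torch.ones(3)
b = torch.nn.Parameter(torch.ones(2))            # no grad yet -> filtered out
print(len(filter_grads_for_norm([w, b], lambda p: True, lambda p: True)))  # 1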