Commit 87833e1f
authored Nov 30, 2020 by Jeff Rasley

calculate grad norm wrt sub partitions

parent 17f36f1b
Pipeline #200 failed in 0 seconds
Showing 2 changed files with 48 additions and 8 deletions:

deepspeed/runtime/utils.py (+7, -6)
deepspeed/runtime/zero/stage1.py (+41, -2)
deepspeed/runtime/utils.py

@@ -67,7 +67,7 @@ class CheckOverflow(object):
         return bool(overflow)
 
-    def check(self, param_groups=None):
+    def check(self, param_groups=None, raw_grads=False):
         params = []
         if param_groups is None:
             params = self.params
@@ -79,17 +79,18 @@ class CheckOverflow(object):
                 for param in group:
                     params.append(param)
 
-        return self.has_overflow(params)
+        return self.has_overflow(params, raw_grads)
 
     # `params` is a list / generator of torch.Variable
-    def has_overflow_serial(self, params):
+    def has_overflow_serial(self, params, raw_grads=False):
         for i, p in enumerate(params):
-            if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
+            grad = p if raw_grads else p.grad
+            if grad is not None and self._has_inf_or_nan(grad.data, i):
                 return True
         return False
 
-    def has_overflow(self, params):
-        overflow = self.has_overflow_serial(params)
+    def has_overflow(self, params, raw_grads=False):
+        overflow = self.has_overflow_serial(params, raw_grads)
         # Since each model parallel GPU carries only part of the model,
         # make sure overflow flag is synced across all the model parallel GPUs
         overflow_gpu = torch.cuda.ByteTensor([overflow])
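Note on the new raw_grads flag above: it lets the overflow check run directly on gradient tensors the caller already holds (for example flat ZeRO gradient sub-partitions) instead of reading each parameter's .grad. A minimal single-process sketch of the two paths; has_inf_or_nan below is an illustrative stand-in for the private _has_inf_or_nan helper, not the real implementation:

    import torch

    def has_inf_or_nan(t):
        # Stand-in check: an inf/NaN anywhere in the tensor makes the float sum inf/NaN.
        s = t.float().sum().item()
        return s == float('inf') or s == -float('inf') or s != s

    # raw_grads=False path: iterate parameters and inspect their .grad fields.
    params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
    for p in params:
        p.grad = torch.randn_like(p)
    overflow = any(p.grad is not None and has_inf_or_nan(p.grad.data) for p in params)

    # raw_grads=True path: the caller already holds flat gradient tensors
    # (e.g. local gradient sub-partitions), so they are checked directly.
    flat_grads = [torch.randn(8), torch.full((8, ), float('inf'))]
    overflow_raw = any(g is not None and has_inf_or_nan(g.data) for g in flat_grads)

    print(overflow, overflow_raw)  # overflow_raw is True because of the inf partition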
deepspeed/runtime/zero/stage1.py

@@ -6,7 +6,7 @@ from collections import defaultdict
 from deepspeed.runtime.zero.utils import _initialize_parameter_parallel_groups
 from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
-from deepspeed.runtime.utils import get_grad_norm, CheckOverflow
+from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, is_model_parallel_parameter
 from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_OPTIMIZER_STATES
 from deepspeed.utils import logger, log_dist
@@ -642,7 +642,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         partition_id = dist.get_rank(group=self.dp_process_group)
         for i, group in enumerate(self.fp16_groups):
             #TODO RS: update get grad norm to support sub partitions
-            norm_groups.append(get_grad_norm(group, mpu=self.mpu))
+            # norm_groups.append(get_grad_norm(group, mpu=self.mpu))
 
             #RS: update free grads w.r.t. sub partitions
             #free gradients for all the parameters that are not updated by this process
@@ -667,6 +667,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                 self.free_grad_in_param_list(
                     self.params_in_rank_sub_partitions[i][partition_id])
 
+            # calculate grad norm w.r.t. local sub partitions
+            norm_groups.append(
+                self.get_grad_norm_sub_partitions(local_grad_sub_partitions,
+                                                  mpu=self.mpu))
+
             local_sub_partitions_grad_groups.append(local_grad_sub_partitions)
 
         #RS: update unscale/clip with sub partitions
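With the per-parameter get_grad_norm call commented out in the previous hunk, the gradient norm of each group is now built from the flat gradient sub-partitions this rank owns (gradients of parameters owned by other ranks have just been freed). This works because the squared L2 norm is additive over disjoint partitions; a quick single-process check of that identity, with torch.chunk standing in for the rank-to-sub-partition split:

    import torch

    flat_grad = torch.randn(16)
    # Pretend each chunk is the gradient sub-partition owned by a different DP rank.
    sub_partitions = torch.chunk(flat_grad, 4)
    # Sum of squared per-partition norms: the quantity the SUM all-reduce accumulates.
    total_sq = sum(p.float().norm(2.0)**2 for p in sub_partitions)
    print(torch.allclose(total_sq**0.5, flat_grad.norm(2.0)))  # True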
@@ -706,6 +711,40 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         return self.overflow
 
+    def get_grad_norm_sub_partitions(self, sub_partitions, mpu):
+        norm_type = 2.0
+        total_norm = 0.
+        for partition in sub_partitions:
+            if mpu is not None:
+                # if (mpu.get_model_parallel_rank() == 0
+                #         ) or is_model_parallel_parameter(p):
+                #     param_norm = p.grad.data.float().norm(norm_type)
+                #     total_norm += param_norm.item()**norm_type
+                raise NotImplementedError(
+                    "support grad norm of model parallel parameters")
+            else:
+                param_norm = partition.data.float().norm(norm_type)
+                total_norm += param_norm.item()**norm_type
+
+        # Sum across all DP ranks who each have different grad sub-partitions
+        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+        torch.distributed.all_reduce(total_norm_cuda,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=self.dp_process_group)
+
+        if mpu is not None:
+            # Sum across all model parallel GPUs.
+            torch.distributed.all_reduce(total_norm_cuda,
+                                         op=torch.distributed.ReduceOp.SUM,
+                                         group=mpu.get_model_parallel_group())
+
+        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+
+        if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
+            total_norm = -1
+
+        return total_norm
+
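The new method only computes the local partial sum of squared norms; the SUM all-reduce over self.dp_process_group folds in the other ranks' sub-partitions (the second all-reduce over the model-parallel group is present but unreachable for now, since the mpu branch raises NotImplementedError). A small self-contained harness that emulates the same reduction pattern on CPU with the gloo backend; the two-rank setup, address, and port are illustrative only:

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def run(rank, world_size):
        dist.init_process_group('gloo',
                                init_method='tcp://127.0.0.1:29512',
                                rank=rank,
                                world_size=world_size)
        torch.manual_seed(0)
        full_grad = torch.randn(16)                       # identical on every rank, for the check
        local = torch.chunk(full_grad, world_size)[rank]  # this rank's gradient sub-partition
        # Local sum of squared norms, then SUM all-reduce across data-parallel ranks.
        partial = torch.tensor([float(local.float().norm(2.0)**2)])
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        total_norm = partial[0].item()**0.5
        if rank == 0:
            print(abs(total_norm - full_grad.norm(2.0).item()) < 1e-5)  # True
        dist.destroy_process_group()

    if __name__ == '__main__':
        mp.spawn(run, args=(2, ), nprocs=2)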
     def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
         total_norm = 0.0
         for norm in norm_groups:
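The hunk's trailing context shows where these per-group norms end up: unscale_and_clip_grads folds norm_groups into a single global norm before unscaling. The rest of that method lies outside this diff, so the following is only a sketch of the usual combine-then-clip pattern such a loop implements; the loss_scale and clip_grad parameters and the exact clipping rule are assumptions, not taken from this commit:

    import math
    import torch

    def unscale_and_clip(grad_groups_flat, norm_groups, loss_scale=1024.0, clip_grad=1.0):
        # Combine per-group norms: ||g|| = sqrt(sum_i ||g_i||^2), using the same additivity as above.
        total_norm = math.sqrt(sum(norm**2.0 for norm in norm_groups))
        # Undo the loss scale; shrink further only if the combined norm exceeds the clip threshold.
        combined_scale = loss_scale
        if clip_grad > 0.0:
            clip = total_norm / (loss_scale * clip_grad)
            if clip > 1.0:
                combined_scale = clip * loss_scale
        for flat_grad in grad_groups_flat:
            flat_grad.mul_(1.0 / combined_scale)

    flat_groups = [torch.randn(10) * 1024.0, torch.randn(6) * 1024.0]
    norms = [float(g.norm(2.0)) for g in flat_groups]
    unscale_and_clip(flat_groups, norms)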