OpenDAS / deepspeed / Commits / 376818ef

Unverified commit 376818ef, authored Jul 15, 2020 by Jeff Rasley, committed by GitHub on Jul 15, 2020.

Empty grad fix (#291)

* empty grad fix
* add unit tests for empty grad

parent 607814fe
Changes: 3 changed files with 101 additions and 9 deletions (+101 / -9)

deepspeed/pt/deepspeed_light.py   +11 / -1
tests/unit/simple_model.py         +8 / -3
tests/unit/test_fp16.py           +82 / -5
deepspeed/pt/deepspeed_light.py

@@ -979,7 +979,17 @@ class DeepSpeedLight(Module):

     def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000):
         grads = []
         for param_name, param in self.module.named_parameters():
-            if param.grad is not None:
+            if param.grad is None:
+                # In cases where there is an imbalance of empty grads across
+                # ranks we must create empty grads, this will ensure that every
+                # rank is reducing the same size. In some cases it may make
+                # sense in the future to support the ability to average not
+                # w.r.t. world size but with a different value.
+                grads.append(
+                    torch.zeros(param.size(),
+                                dtype=param.dtype,
+                                device=param.device))
+            else:
                 grad_data = param.grad.data
                 if self.sparse_gradients_enabled(
                 ) and param_name in self.csr_tensor_module_names:
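The change above is the heart of the fix: before the fallback allreduce gathers gradients, any parameter whose .grad is still None contributes a zero tensor of matching shape, dtype, and device, so every rank hands the collective the same number and sizes of buffers. A minimal sketch of that pattern in isolation (the gather_grads helper is illustrative, not DeepSpeed API):

import torch

def gather_grads(module):
    """Collect one gradient tensor per parameter, substituting zeros for
    parameters that never received a gradient on this rank."""
    grads = []
    for _, param in module.named_parameters():
        if param.grad is None:
            # Pad with zeros so every rank reduces identically sized buffers.
            grads.append(torch.zeros(param.size(),
                                     dtype=param.dtype,
                                     device=param.device))
        else:
            grads.append(param.grad.data)
    return grads

Averaging these zero-padded gradients by world size slightly dilutes the update for the ranks that did produce a real gradient, which appears to be what the in-code comment hints at when it mentions averaging with a value other than world size.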
tests/unit/simple_model.py

@@ -5,16 +5,21 @@ import torch

 class SimpleModel(torch.nn.Module):
-    def __init__(self, hidden_dim, empty_grad=False):
+    def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(SimpleModel, self).__init__()
         self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
         if empty_grad:
-            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim)])
+            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
         self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+        self.rank = rank
+        self.empty_grad = empty_grad

     def forward(self, x, y):
         hidden_dim = x
-        hidden_dim = self.linear(hidden_dim)
+        if self.rank == 0 and self.empty_grad:
+            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
+        else:
+            hidden_dim = self.linear(hidden_dim)
         return self.cross_entropy_loss(hidden_dim, y)
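The reworked SimpleModel is what lets the new tests reproduce the imbalance: only rank 0 routes activations through linear2, so after a backward pass linear2's parameters have gradients on rank 0 while param.grad stays None on every other rank. A toy, single-process illustration of that effect (TwoBranchModel is a stand-in defined here, with the rank passed in by hand rather than coming from a launcher):

import torch

class TwoBranchModel(torch.nn.Module):
    """Toy stand-in for the reworked SimpleModel: only rank 0 uses linear2."""
    def __init__(self, hidden_dim, rank):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.rank = rank

    def forward(self, x):
        if self.rank == 0:
            return self.linear(x) + self.linear2(x)
        return self.linear(x)

for rank in (0, 1):
    model = TwoBranchModel(hidden_dim=10, rank=rank)
    model(torch.randn(4, 10)).sum().backward()
    print(rank, model.linear2.weight.grad is None)
# 0 False  -> linear2 received a gradient on "rank 0"
# 1 True   -> linear2.weight.grad stays None on the other "rank"

Without the zero-filling in buffered_allreduce_fallback, two ranks running this model would submit mismatched reduction buffers for exactly these parameters.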
tests/unit/test_fp16.py

@@ -33,9 +33,10 @@ def test_lamb_fp32_grad_clip(tmpdir):

         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
-                                        device=model.device)
+                                        device=model.device,
+                                        dtype=torch.float)
         for n, batch in enumerate(data_loader):
-            loss = model(batch[0].float(), batch[1])
+            loss = model(batch[0], batch[1])
             model.backward(loss)
             model.step()
@@ -81,7 +82,7 @@ def test_lamb_fp16_basic(tmpdir):

 def test_lamb_fp16_empty_grad(tmpdir):
     config_dict = {
-        "train_batch_size": 1,
+        "train_batch_size": 2,
         "steps_per_print": 1,
         "optimizer": {
             "type": "Lamb",
@@ -97,9 +98,9 @@ def test_lamb_fp16_empty_grad(tmpdir):

     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10
-    model = SimpleModel(hidden_dim, empty_grad=True)
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)

-    @distributed_test(world_size=[1])
+    @distributed_test(world_size=[2])
     def _test_lamb_fp16_empty_grad(args, model, hidden_dim):
         model, _, _, _ = deepspeed.initialize(args=args,
                                               model=model,
@@ -116,6 +117,44 @@ def test_lamb_fp16_empty_grad(tmpdir):

     _test_lamb_fp16_empty_grad(args=args, model=model, hidden_dim=hidden_dim)


+def test_adam_fp32_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "fp16": {
+            "enabled": False
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_fp32_empty_grad(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_fp32_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
+
+
 def test_adamw_fp16_basic(tmpdir):
     config_dict = {
         "train_batch_size": 1,
@@ -495,3 +534,41 @@ def test_adam_amp_o2(tmpdir):

             model.step()

     _test_adam_amp_o2(args=args, model=model, hidden_dim=hidden_dim)
+
+
+def test_adam_amp_o2_empty_grad(tmpdir):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "gradient_clipping": 1.0,
+        "amp": {
+            "enabled": True,
+            "opt_level": "O2"
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    model = SimpleModel(hidden_dim, empty_grad=False, rank=args.local_rank)
+
+    @distributed_test(world_size=[2])
+    def _test_adam_amp_o2_empty_grad(args, model, hidden_dim):
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+        data_loader = random_dataloader(model=model,
+                                        total_samples=50,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_adam_amp_o2_empty_grad(args=args, model=model, hidden_dim=hidden_dim)
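Outside the pytest harness, the loop these new tests exercise looks roughly like the sketch below. It is not a drop-in script: it assumes a two-process launch, a DeepSpeed JSON config equivalent to the test_adam_fp32_empty_grad config_dict (fp16 disabled), that SimpleModel is importable from the tests/unit helper, and that deepspeed.add_config_arguments wires up the config flag as in other DeepSpeed examples.

import argparse
import torch
import deepspeed
from simple_model import SimpleModel  # test helper in tests/unit/

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser = deepspeed.add_config_arguments(parser)  # assumed to add --deepspeed_config
args = parser.parse_args()

hidden_dim = 10
# Ranks other than 0 never use linear2, so its gradients stay empty on purpose.
model = SimpleModel(hidden_dim, empty_grad=True, rank=args.local_rank)
engine, _, _, _ = deepspeed.initialize(args=args,
                                       model=model,
                                       model_parameters=model.parameters())

for _ in range(50):
    x = torch.randn(1, hidden_dim, device=engine.device)
    y = torch.randint(0, hidden_dim, (1,), device=engine.device)
    loss = engine(x, y)
    engine.backward(loss)  # without the fix, the gradient allreduce triggered here can mismatch across ranks
    engine.step()

Running the same loop on a single process makes the problem invisible, which is why the tests moved from world_size=[1] to world_size=[2].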