Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
121b9db0
Unverified
Commit
121b9db0
authored
Apr 06, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Apr 06, 2021
Browse files
[fix][OSS] two small hotfixes.. repro not obvious for grad_fn (#583)
parent
168c9baa
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
19 deletions
+57
-19
fairscale/nn/data_parallel/sharded_ddp.py
fairscale/nn/data_parallel/sharded_ddp.py
+16
-10
fairscale/optim/oss.py
fairscale/optim/oss.py
+6
-5
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_features.py
+33
-2
tests/optim/test_oss.py
tests/optim/test_oss.py
+2
-2
No files found.
fairscale/nn/data_parallel/sharded_ddp.py
View file @
121b9db0
...
@@ -494,18 +494,24 @@ class ShardedDataParallel(nn.Module):
...
@@ -494,18 +494,24 @@ class ShardedDataParallel(nn.Module):
if
param
.
grad
is
not
None
and
param
.
grad
.
requires_grad
:
if
param
.
grad
is
not
None
and
param
.
grad
.
requires_grad
:
raise
RuntimeError
(
"ShardedDataParallel only works with gradients that don't require grad"
)
raise
RuntimeError
(
"ShardedDataParallel only works with gradients that don't require grad"
)
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
p_tmp
=
param
.
expand_as
(
param
)
p_tmp
=
param
.
expand_as
(
param
)
assert
p_tmp
.
grad_fn
is
not
None
grad_acc
=
p_tmp
.
grad_fn
.
next_functions
[
0
][
0
]
dst_rank
=
self
.
_trainable_param_to_rank
[
param
]
reduce_function
=
self
.
_get_reduce_fn
(
index
,
param
,
dst_rank
)
self
.
_grad_hooks
.
append
(
grad_acc
.
register_hook
(
reduce_function
))
# See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
self
.
_grad_accs
.
append
(
grad_acc
)
# keep this hook in scope
# We're interested in the tensors which will be tracked by Autograd
self
.
_manual_reduce
.
append
(
reduce_function
)
# Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
# these do not need to be sync'ed
if
p_tmp
.
grad_fn
is
not
None
:
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
# (by default the hook with Pytorch is a pre-grad, not a post-grad)
grad_acc
=
p_tmp
.
grad_fn
.
next_functions
[
0
][
0
]
dst_rank
=
self
.
_trainable_param_to_rank
[
param
]
reduce_function
=
self
.
_get_reduce_fn
(
index
,
param
,
dst_rank
)
self
.
_grad_hooks
.
append
(
grad_acc
.
register_hook
(
reduce_function
))
self
.
_grad_accs
.
append
(
grad_acc
)
# keep this hook in scope
self
.
_manual_reduce
.
append
(
reduce_function
)
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_sync_params_and_buffers
(
self
)
->
None
:
def
_sync_params_and_buffers
(
self
)
->
None
:
...
...
fairscale/optim/oss.py
View file @
121b9db0
...
@@ -591,13 +591,14 @@ class OSS(Optimizer):
...
@@ -591,13 +591,14 @@ class OSS(Optimizer):
# Merge all the trainable params in a single bucket
# Merge all the trainable params in a single bucket
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
requires_grad
,
params
))
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
requires_grad
,
params
))
buffer_size
=
sum
(
map
(
lambda
x
:
x
.
numel
(),
trainable_params
))
if
trainable_params
:
bucket
=
ParamBucket
(
size
=
buffer_size
,
dtype
=
params
[
0
].
dtype
,
device
=
device
)
buffer_size
=
sum
(
map
(
lambda
x
:
x
.
numel
(),
trainable_params
))
bucket
=
ParamBucket
(
size
=
buffer_size
,
dtype
=
trainable_params
[
0
].
dtype
,
device
=
device
)
for
param
in
trainable_params
:
for
param
in
trainable_params
:
bucket
.
add_param
(
param
)
bucket
.
add_param
(
param
)
self
.
buckets
[
device
][
dst_rank
]
=
bucket
self
.
buckets
[
device
][
dst_rank
]
=
bucket
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use
=
list
(
self
.
_per_device_params
.
keys
())
devices_in_use
=
list
(
self
.
_per_device_params
.
keys
())
...
...
tests/nn/data_parallel/test_sharded_ddp_features.py
View file @
121b9db0
...
@@ -30,8 +30,25 @@ from fairscale.utils.testing import (
...
@@ -30,8 +30,25 @@ from fairscale.utils.testing import (
)
)
def
_get_mlp
():
def
_get_mlp
(
tripwire
:
bool
=
False
):
return
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
))
if
not
tripwire
:
return
Sequential
(
Linear
(
2
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
),
Linear
(
3
,
3
))
class
Tripwire
(
torch
.
nn
.
Module
):
"""A model made to expose possible corner cases
"""
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
model
=
Linear
(
2
,
3
,
bias
=
False
)
# mismatched types in between trainable or not, can trip the buckets for instance
self
.
register_parameter
(
"tripwire"
,
torch
.
nn
.
Parameter
(
torch
.
LongTensor
((
3
,
3
)),
requires_grad
=
False
))
def
forward
(
self
,
x
):
return
self
.
model
(
x
)
return
Tripwire
()
class
_DoubleInput
(
torch
.
nn
.
Module
):
class
_DoubleInput
(
torch
.
nn
.
Module
):
...
@@ -231,6 +248,20 @@ def test_random_attributes():
...
@@ -231,6 +248,20 @@ def test_random_attributes():
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
def
test_mixed_types
():
# Check that ShardedDDP exposes the original module's attributes
dist
.
init_process_group
(
init_method
=
"file://"
+
tempfile
.
mkstemp
()[
1
],
backend
=
"gloo"
,
rank
=
0
,
world_size
=
1
)
model
=
_get_mlp
(
tripwire
=
True
)
optimizer
=
OSS
(
params
=
model
.
parameters
(),
optim
=
torch
.
optim
.
SGD
,
lr
=
1e-3
,
momentum
=
0.99
)
model
=
ShardedDataParallel
(
model
,
optimizer
)
input_tensor
=
torch
.
rand
((
2
,
2
))
_
=
model
(
input_tensor
)
dist
.
destroy_process_group
()
def
run_test_device_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
def
run_test_device_change
(
rank
,
world_size
,
backend
,
device
,
temp_file_name
,
reduce_buffer_size
):
# Check that the wrapped module can change devices
# Check that the wrapped module can change devices
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
dist
.
init_process_group
(
init_method
=
"file://"
+
temp_file_name
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
...
tests/optim/test_oss.py
View file @
121b9db0
...
@@ -322,8 +322,8 @@ def run_test_step(rank, world_size, tempfile_name):
...
@@ -322,8 +322,8 @@ def run_test_step(rank, world_size, tempfile_name):
dist
.
all_reduce
(
p
.
grad
.
data
,
op
=
dist
.
ReduceOp
.
SUM
)
dist
.
all_reduce
(
p
.
grad
.
data
,
op
=
dist
.
ReduceOp
.
SUM
)
p
.
grad
.
data
/=
world_size
p
.
grad
.
data
/=
world_size
o
.
step
()
o
.
step
()
assert
m
.
weight
==
torch
.
tensor
([[
0.75
]],
device
=
rank
)
assert
m
.
weight
==
torch
.
tensor
([[
0.75
]],
device
=
rank
)
,
f
"
{
rank
}
:
{
m
.
weight
.
item
()
}
, 0.75 expected"
assert
m
.
bias
==
torch
.
tensor
([
1.85
],
device
=
rank
)
assert
m
.
bias
==
torch
.
tensor
([
1.85
],
device
=
rank
)
,
f
"
{
rank
}
:
{
m
.
bias
.
item
()
}
, 1.85 expected"
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment