OpenDAS / fairscale · Commits

Commit 121b9db0 (unverified)

[fix][OSS] two small hotfixes.. repro not obvious for grad_fn (#583)

Authored by Benjamin Lefaudeux on Apr 06, 2021; committed via GitHub on Apr 06, 2021.
Parent: 168c9baa

Showing 4 changed files with 57 additions and 19 deletions (+57 -19)
fairscale/nn/data_parallel/sharded_ddp.py             +16 -10
fairscale/optim/oss.py                                  +6  -5
tests/nn/data_parallel/test_sharded_ddp_features.py    +33  -2
tests/optim/test_oss.py                                 +2  -2
fairscale/nn/data_parallel/sharded_ddp.py

@@ -494,18 +494,24 @@ class ShardedDataParallel(nn.Module):
             if param.grad is not None and param.grad.requires_grad:
                 raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
 
-            # Register the hook to the next function in line,
-            # so that the hook is fired when this grad has properly been computed
+            # We're interested in the tensors which will be tracked by Autograd
+            # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
             p_tmp = param.expand_as(param)
-            assert p_tmp.grad_fn is not None
-            grad_acc = p_tmp.grad_fn.next_functions[0][0]
-            dst_rank = self._trainable_param_to_rank[param]
-            reduce_function = self._get_reduce_fn(index, param, dst_rank)
 
-            self._grad_hooks.append(grad_acc.register_hook(reduce_function))
-            self._grad_accs.append(grad_acc)  # keep this hook in scope
-            self._manual_reduce.append(reduce_function)
+            # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
+            # these do not need to be sync'ed
+            if p_tmp.grad_fn is not None:
+                # Register the hook to the next function in line,
+                # so that the hook is fired when this grad has properly been computed
+                # (by default the hook with Pytorch is a pre-grad, not a post-grad)
+                grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                dst_rank = self._trainable_param_to_rank[param]
+                reduce_function = self._get_reduce_fn(index, param, dst_rank)
+
+                self._grad_hooks.append(grad_acc.register_hook(reduce_function))
+                self._grad_accs.append(grad_acc)  # keep this hook in scope
+                self._manual_reduce.append(reduce_function)
 
     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
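For context on the sharded_ddp.py change above: ShardedDataParallel reaches the AccumulateGrad node of each parameter through param.expand_as(param).grad_fn.next_functions[0][0] and attaches its reduce function there, and the new guard simply skips parameters that autograd is not tracking (for example a tensor created with requires_grad=False, such as the frozen LongTensor used in the new test below) instead of tripping the old assert. A minimal sketch of that pattern follows; this is not fairscale code, and the tiny Linear model, the grad_accs list and the print hook are illustrative stand-ins.

# Minimal sketch (not fairscale code) of the AccumulateGrad hook pattern used above.
import torch

lin = torch.nn.Linear(2, 3)
frozen = torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False)  # not tracked by autograd

grad_accs = []  # keep the AccumulateGrad nodes in scope, mirroring _grad_accs above
for p in list(lin.parameters()) + [frozen]:
    p_tmp = p.expand_as(p)  # a non-leaf view whose graph leads back to the parameter
    if p_tmp.grad_fn is None:
        # e.g. requires_grad=False: autograd has nothing to hook, so skip instead of asserting
        continue
    # next_functions[0][0] is the AccumulateGrad node; its hook fires during backward
    # once this parameter's gradient is being accumulated
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(lambda *unused, p=p: print("grad ready for", tuple(p.shape)))
    grad_accs.append(grad_acc)

lin(torch.rand(4, 2)).sum().backward()  # the hooks for weight and bias fire here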
fairscale/optim/oss.py

@@ -591,13 +591,14 @@ class OSS(Optimizer):
                 # Merge all the trainable params in a single bucket
                 trainable_params = list(filter(lambda x: x.requires_grad, params))
-                buffer_size = sum(map(lambda x: x.numel(), trainable_params))
-                bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)
-                for param in trainable_params:
-                    bucket.add_param(param)
-                self.buckets[device][dst_rank] = bucket
+                if trainable_params:
+                    buffer_size = sum(map(lambda x: x.numel(), trainable_params))
+                    bucket = ParamBucket(size=buffer_size, dtype=trainable_params[0].dtype, device=device)
+                    for param in trainable_params:
+                        bucket.add_param(param)
+                    self.buckets[device][dst_rank] = bucket
 
         # Clear the buffer keys which are not in use anymore (could be that the devices changed)
         devices_in_use = list(self._per_device_params.keys())
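The oss.py hunk above guards two related corner cases when flattening a shard into a bucket: a device/rank whose parameter list holds no trainable tensors (so trainable_params[0] would raise, and there is nothing to bucket), and a list whose first entry is a non-trainable parameter of another dtype, which previously sized the bucket from params[0].dtype. A rough illustration under those assumptions, with torch.empty standing in for fairscale's ParamBucket and build_bucket as a hypothetical helper:

# Rough illustration of the two corner cases; build_bucket is hypothetical and
# torch.empty stands in for fairscale's ParamBucket.
import torch

def build_bucket(params, device="cpu"):
    trainable_params = list(filter(lambda x: x.requires_grad, params))
    if not trainable_params:
        return None  # nothing to flatten for this device/rank
    buffer_size = sum(p.numel() for p in trainable_params)
    # size the buffer from a *trainable* parameter: params[0] may be e.g. a frozen int64 tensor
    return torch.empty(buffer_size, dtype=trainable_params[0].dtype, device=device)

frozen = torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False)
weight = torch.nn.Parameter(torch.rand(3, 2))

print(build_bucket([frozen]))                # None: no trainable params in this shard
print(build_bucket([frozen, weight]).dtype)  # torch.float32, not torch.int64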
tests/nn/data_parallel/test_sharded_ddp_features.py

@@ -30,8 +30,25 @@ from fairscale.utils.testing import (
 )
 
 
-def _get_mlp():
-    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+def _get_mlp(tripwire: bool = False):
+    if not tripwire:
+        return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+    class Tripwire(torch.nn.Module):
+        """A model made to expose possible corner cases
+        """
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.model = Linear(2, 3, bias=False)
+
+            # mismatched types in between trainable or not, can trip the buckets for instance
+            self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))
+
+        def forward(self, x):
+            return self.model(x)
+
+    return Tripwire()
 
 
 class _DoubleInput(torch.nn.Module):

@@ -231,6 +248,20 @@ def test_random_attributes():
     dist.destroy_process_group()
 
 
+def test_mixed_types():
+    # Check that ShardedDDP exposes the original module's attributes
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+
+    model = _get_mlp(tripwire=True)
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    model = ShardedDataParallel(model, optimizer)
+    input_tensor = torch.rand((2, 2))
+    _ = model(input_tensor)
+
+    dist.destroy_process_group()
+
+
 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
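Assuming a local checkout with pytest installed, the new regression test can be run on its own with python -m pytest tests/nn/data_parallel/test_sharded_ddp_features.py -k test_mixed_types; it builds the tripwire model (a trainable float Linear plus a frozen integer parameter) under a single-rank gloo group, which exercises both the grad_fn guard and the bucket dtype fix at once.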
tests/optim/test_oss.py

@@ -322,8 +322,8 @@ def run_test_step(rank, world_size, tempfile_name):
             dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
             p.grad.data /= world_size
     o.step()
 
-    assert m.weight == torch.tensor([[0.75]], device=rank)
-    assert m.bias == torch.tensor([1.85], device=rank)
+    assert m.weight == torch.tensor([[0.75]], device=rank), f"{rank}: {m.weight.item()}, 0.75 expected"
+    assert m.bias == torch.tensor([1.85], device=rank), f"{rank}: {m.bias.item()}, 1.85 expected"
 
     dist.destroy_process_group()