OpenDAS / fairscale · Commits

Unverified commit 84a3bdbe, authored Jan 01, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 01, 2021
[fix] Typo in ShardedDDP unit test (#282)
* fix typo, backend for CPU test
Parent: 1c8d219d
Showing 2 changed files with 8 additions and 13 deletions (+8 -13)

fairscale/nn/data_parallel/sharded_ddp.py (+6 -7)
tests/nn/data_parallel/test_sharded_ddp.py (+2 -6)
fairscale/nn/data_parallel/sharded_ddp.py
@@ -129,9 +129,9 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
-        """
-        .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4
         This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")

@@ -157,8 +157,7 @@ class ShardedDataParallel(nn.Module):
         self.should_accumulate_grads = old_should_accumulate_grads

     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}

@@ -254,14 +253,14 @@ class ShardedDataParallel(nn.Module):
         _ = list(map(lambda x: x.wait(), work_handles))

-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
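Note on the deprecated reduce() above: with ShardedDDP, gradient reduction is hooked into the backward pass, so user code never has to call it. Below is a minimal usage sketch, not from this repository, assuming fairscale's OSS and ShardedDataParallel imports, a simple two-argument ShardedDataParallel constructor, and a single-rank "gloo" group like the one used in the unit test further down.

import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Single-process "gloo" group, mirroring the fixed CPU unit test below.
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp = ShardedDataParallel(model, optimizer)  # assumed two-argument constructor

# One training step: gradients are reduced automatically during backward(),
# so ddp.reduce() is never called.
loss = ddp(torch.randn(8, 2)).sum()
loss.backward()
optimizer.step()

dist.destroy_process_group()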
tests/nn/data_parallel/test_sharded_ddp.py
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)

@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(
         run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )
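The first test hunk is the fix named in the commit message: test_ddp_attributes() runs a single CPU process, and the NCCL backend only works with CUDA devices, so the process group has to be created with "gloo". The sketch below is a small hypothetical helper, not part of the test suite, that makes the same backend choice explicit.

import tempfile

import torch
import torch.distributed as dist


def init_single_process_group() -> str:
    """Create a world_size=1 process group, picking a backend the host can support."""
    # NCCL needs CUDA devices; CPU-only hosts (and CPU-only tests) must use "gloo".
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
    return backend


if __name__ == "__main__":
    print("initialized with backend:", init_single_process_group())
    dist.destroy_process_group()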