OpenDAS / fairscale · Commits

Commit 84a3bdbe (unverified)
Authored Jan 01, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 01, 2021
Parent: 1c8d219d

[fix] Typo in ShardedDDP unit test (#282)

* fix typo, backend for CPU test

2 changed files with 8 additions and 13 deletions:
  fairscale/nn/data_parallel/sharded_ddp.py   +6 -7
  tests/nn/data_parallel/test_sharded_ddp.py  +2 -6
fairscale/nn/data_parallel/sharded_ddp.py
@@ -129,7 +129,7 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
-        """
-        .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4

         This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
@@ -157,8 +157,7 @@ class ShardedDataParallel(nn.Module):
         self.should_accumulate_grads = old_should_accumulate_grads

     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
@@ -254,14 +253,14 @@ class ShardedDataParallel(nn.Module):
         _ = list(map(lambda x: x.wait(), work_handles))

-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
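For context, the assertion and the `_specify_ddp_gpu_num` call in the hunk above are the SyncBatchNorm plumbing inside `ShardedDataParallel`: sync batch norm only works when the wrapped module lives on a GPU, and gradient reduction is hooked into the backward pass. The sketch below shows one plausible way to exercise that path; it is not part of this commit, and the tiny Conv/BatchNorm model, the single-process NCCL group and the import paths are assumptions made for illustration.

    # Minimal sketch, assuming a single GPU and fairscale's public import paths.
    import tempfile

    import torch
    import torch.distributed as dist
    from torch.nn import BatchNorm2d, Conv2d, Sequential

    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim.oss import OSS

    # Single-process group, as in the tests; NCCL because SyncBatchNorm needs GPU modules.
    dist.init_process_group(
        init_method="file://" + tempfile.mkstemp()[1], backend="nccl", rank=0, world_size=1
    )

    model = Sequential(Conv2d(3, 8, kernel_size=3), BatchNorm2d(8)).cuda()
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)  # BatchNorm -> SyncBatchNorm

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)  # asserts the module is not on CPU

    out = ddp_model(torch.randn(8, 3, 16, 16, device="cuda"))
    out.sum().backward()  # gradient reduction happens automatically during the backward pass
    optimizer.step()
    dist.destroy_process_group()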
tests/nn/data_parallel/test_sharded_ddp.py
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
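This hunk is the "backend for CPU test" fix from the commit message: `test_ddp_attributes` runs a single process on CPU, so it cannot initialize a NCCL group, and the backend is now hard-coded to gloo. A minimal sketch of that single-process, file-based gloo initialization follows; it is a standalone illustration, not the full test body.

    # Sketch of the single-process CPU setup the test now uses; the file-based
    # init_method avoids having to pick a free TCP port.
    import tempfile

    import torch.distributed as dist

    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
    assert dist.get_backend() == "gloo" and dist.get_world_size() == 1
    # ... build the model and OSS optimizer, check the ShardedDataParallel attributes ...
    dist.destroy_process_group()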
@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
     mp.spawn(
-        run_test_ddp_sync_batch_norm,
-        args=(world_size, backend, device, temp_file_name),
-        nprocs=world_size,
-        join=True,
+        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )
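The `mp.spawn` change above is purely a reflow of the call onto one argument line. For reference, `torch.multiprocessing.spawn` prepends the process rank to `args`, so the worker is invoked as `run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name)`. The sketch below only illustrates that contract; the worker body and the chosen backend are assumptions, not the actual test helper.

    # Illustration of the torch.multiprocessing.spawn contract only; the real
    # run_test_ddp_sync_batch_norm lives in the test file and does more than this.
    import tempfile

    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
        # spawn() passes the process index as the first positional argument.
        dist.init_process_group(
            init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
        )
        # ... convert the model to SyncBatchNorm, wrap it in ShardedDataParallel, run a step ...
        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size, device = 2, "cuda"
        backend = "gloo"  # chosen so the sketch runs anywhere; the real test sets its own backend
        temp_file_name = tempfile.mkstemp()[1]
        mp.spawn(
            run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
        )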