OpenDAS / fairscale

Commit 84a3bdbe (unverified)
Authored Jan 01, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 01, 2021
[fix] Typo in ShardedDDP unit test (#282)
* fix typo, backend for CPU test
Parent: 1c8d219d

Showing 2 changed files with 8 additions and 13 deletions (+8 / -13):

    fairscale/nn/data_parallel/sharded_ddp.py     +6 / -7
    tests/nn/data_parallel/test_sharded_ddp.py    +2 / -6
fairscale/nn/data_parallel/sharded_ddp.py (view file @ 84a3bdbe)
@@ -129,7 +129,7 @@ class ShardedDataParallel(nn.Module):
         return self.module(*inputs, **kwargs)
 
     def reduce(self) -> None:
-        """
-        .. deprecated:: 0.0.4
+        """.. deprecated:: 0.0.4
+
         This does not need to be called, the gradient reduction is done automatically during the BW pass
         """
@@ -157,8 +157,7 @@ class ShardedDataParallel(nn.Module):
         self.should_accumulate_grads = old_should_accumulate_grads
 
     def _clear_counters(self) -> None:
-        """ Reset all the grad reduce and call counters
-        """
+        """Reset all the grad reduce and call counters"""
         self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
         self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
@@ -254,14 +253,14 @@ class ShardedDataParallel(nn.Module):
         _ = list(map(lambda x: x.wait(), work_handles))
 
-    def _passing_sync_batchnorm_handle(self, module):
+    def _passing_sync_batchnorm_handle(self, module: nn.Module) -> None:
         """
         Passes handle required for ``torch.nn.modules.SyncBatchNorm``.
         Adapted from ``torch.nn.distributed.DistributedDataParallel``.
         """
         for layer in module.modules():
             if isinstance(layer, torch.nn.modules.SyncBatchNorm):
-                assert self.device_type != 'cpu', "SyncBatchNorm layers only work with GPU modules"
+                assert self.device_type != "cpu", "SyncBatchNorm layers only work with GPU modules"
                 # device_id logic has not been handled, assume single-process single-device
                 # SyncBatchNorm only supports DDP with single-process single-device anyway'
-                layer._specify_ddp_gpu_num(1)
+                layer._specify_ddp_gpu_num(1)  # type: ignore
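For context (not part of this commit): `_passing_sync_batchnorm_handle` is what lets a wrapped module containing `torch.nn.SyncBatchNorm` layers run under ShardedDataParallel, and the assertion above is why such modules must live on GPU. A minimal sketch of that usage, assuming the fairscale/PyTorch APIs of this era; the model, hyperparameters, and single-rank NCCL setup are illustrative only:

    # Sketch only (not from this commit): wrapping a module that contains SyncBatchNorm,
    # assuming the fairscale/PyTorch APIs of this era. The model, hyperparameters, and
    # single-rank NCCL setup are illustrative, not taken from the repository.
    import tempfile

    import torch
    import torch.distributed as dist
    from torch.nn import Linear, Sequential, SyncBatchNorm

    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    if torch.cuda.is_available():
        url = "file://" + tempfile.mkstemp()[1]
        dist.init_process_group(init_method=url, backend="nccl", rank=0, world_size=1)

        # SyncBatchNorm layers are why the wrapped module must live on GPU
        # (the `device_type != "cpu"` assertion above).
        model = Sequential(Linear(2, 3), SyncBatchNorm(3), Linear(3, 3)).cuda()
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)  # passes the SyncBatchNorm handle internally

        ddp_model(torch.randn(8, 2).cuda()).sum().backward()  # grads reduced automatically during the BW pass
        optimizer.step()

        dist.destroy_process_group()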
tests/nn/data_parallel/test_sharded_ddp.py (view file @ 84a3bdbe)
@@ -316,8 +316,7 @@ def test_ddp_attributes():
     # - device_type
     url = "file://" + tempfile.mkstemp()[1]
-    backend = dist.Backend.NCCL
-    dist.init_process_group(init_method=url, backend=backend, rank=0, world_size=1)
+    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
 
     model = Sequential(Linear(2, 3), Linear(3, 3))
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
@@ -352,10 +351,7 @@ def test_ddp_sync_batch_norm():
     temp_file_name = tempfile.mkstemp()[1]
     device = "cuda"
 
     mp.spawn(
-        run_test_ddp_sync_batch_norm,
-        args=(world_size, backend, device, temp_file_name),
-        nprocs=world_size,
-        join=True,
+        run_test_ddp_sync_batch_norm, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True
     )
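The first test hunk is the fix named in the commit title: `test_ddp_attributes` runs on CPU, and the NCCL backend only supports CUDA devices, so the process group has to be initialized with gloo instead. A standalone sketch of the same idea (not taken from the test suite; the file-store URL and tensor are illustrative):

    # Sketch only (not from this commit): a CPU-only test cannot use the NCCL
    # backend, which requires CUDA devices; gloo works on CPU tensors.
    import tempfile

    import torch
    import torch.distributed as dist

    url = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

    t = torch.ones(3)
    dist.all_reduce(t)  # fine on CPU with gloo; NCCL would reject CPU tensors
    print(t)            # tensor([1., 1., 1.]) since world_size == 1

    dist.destroy_process_group()

With world_size=1 this mirrors how the fixed test sets up its process group before constructing the model and the OSS optimizer.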