OpenDAS / Fairseq

Commit fbe8ce65, authored Sep 17, 2018 by Myle Ott
Better support for various c10d API changes
Parent: 78071e0f

Showing 6 changed files with 77 additions and 38 deletions (+77, -38)
fairseq/criterions/adaptive_loss.py           +4   -3
fairseq/distributed_utils.py                  +30  -20
fairseq/models/distributed_fairseq_model.py   +22  -6
fairseq/options.py                            +5   -4
fairseq/trainer.py                            +15  -4
tests/test_binaries.py                        +1   -1
fairseq/criterions/adaptive_loss.py

@@ -22,10 +22,11 @@ class AdaptiveLoss(FairseqCriterion):
     def __init__(self, args, task):
         super().__init__(args, task)

-        if not args.no_c10d:
+        if args.ddp_backend == 'c10d':
             raise Exception(
-                'AdaptiveLoss is not compatible with the c10d version of '
-                'DistributedDataParallel. Please add the `--no-c10d` flag.'
+                'AdaptiveLoss is not compatible with the c10d '
+                'version of DistributedDataParallel. Please use '
+                '`--ddp-backend=no_c10d` instead.'
             )

     def forward(self, model, sample, reduce=True):
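For readers following along, the effect of the updated guard can be sketched in isolation. The helper below is illustrative only (check_ddp_backend and the Namespace stand-in are not fairseq code); it simply mirrors the condition and message introduced above, assuming args carries the new ddp_backend attribute.

from argparse import Namespace

def check_ddp_backend(args):
    # Mirrors the updated check in AdaptiveLoss.__init__: the adaptive softmax
    # criterion cannot run under the c10d DistributedDataParallel wrapper.
    if args.ddp_backend == 'c10d':
        raise Exception(
            'AdaptiveLoss is not compatible with the c10d '
            'version of DistributedDataParallel. Please use '
            '`--ddp-backend=no_c10d` instead.'
        )

check_ddp_backend(Namespace(ddp_backend='no_c10d'))  # passes silently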
fairseq/distributed_utils.py

@@ -5,10 +5,11 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from collections import namedtuple
 import pickle

 import torch
-from torch import distributed
+from torch import distributed, nn
 from torch.distributed import group

 from fairseq import utils

@@ -18,33 +19,42 @@ def is_master(args):
     return args.distributed_rank == 0


-_use_c10d = [None]
+_use_c10d = [True]
+
+C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
+
+if hasattr(nn.parallel, 'deprecated'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=True)
+elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
+    c10d_status = C10dStatus(has_c10d=True, is_default=False)
+else:
+    c10d_status = C10dStatus(has_c10d=False, is_default=False)


 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')

-    if _use_c10d[0] is None:
-        _use_c10d[0] = not args.no_c10d
-
-    if _use_c10d[0] and not hasattr(torch.nn.parallel, '_DistributedDataParallelC10d'):
-        _use_c10d[0] = False
-        print('WARNING: cannot find DistributedDataParallelC10d, '
-              'falling back to standard DistributedDataParallel')
+    if args.ddp_backend == 'no_c10d':
+        _use_c10d[0] = False

     print('| distributed init (rank {}): {}'.format(
         args.distributed_rank, args.distributed_init_method), flush=True)

     if _use_c10d[0]:
-        distributed.c10d.init_process_group(
-            backend=args.distributed_backend,
-            init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size,
-            rank=args.distributed_rank,
-        )
+        if c10d_status.is_default:
+            init_fn = distributed.init_process_group
+        else:
+            init_fn = distributed.c10d.init_process_group
     else:
-        distributed.init_process_group(
-            backend=args.distributed_backend,
-            init_method=args.distributed_init_method,
-            world_size=args.distributed_world_size,
+        if c10d_status.is_default:
+            init_fn = distributed.deprecated.init_process_group
+        else:
+            init_fn = distributed.init_process_group
+
+    init_fn(
+        backend=args.distributed_backend,
+        init_method=args.distributed_init_method,
+        world_size=args.distributed_world_size,
...
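The version detection added above can be read on its own. Below is a minimal standalone sketch of that probe, with comments spelling out what each branch means; the interpretation of the branches is inferred from how c10d_status is used later in this commit, so treat the comments as a reading aid rather than authoritative PyTorch documentation.

from collections import namedtuple
from torch import nn

C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])

if hasattr(nn.parallel, 'deprecated'):
    # Newer PyTorch: the c10d implementation is the default DistributedDataParallel,
    # and the old implementation has moved under nn.parallel.deprecated.
    c10d_status = C10dStatus(has_c10d=True, is_default=True)
elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
    # Transitional PyTorch: c10d exists only as the private _DistributedDataParallelC10d class.
    c10d_status = C10dStatus(has_c10d=True, is_default=False)
else:
    # Older PyTorch: no c10d-based DistributedDataParallel is available.
    c10d_status = C10dStatus(has_c10d=False, is_default=False)

print(c10d_status)  # e.g. C10dStatus(has_c10d=True, is_default=True)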
fairseq/models/distributed_fairseq_model.py

@@ -5,9 +5,10 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.distributed import c10d
 from torch.nn import parallel

+from fairseq.distributed_utils import c10d_status
 from . import BaseFairseqModel

@@ -28,21 +29,36 @@ class DistributedFairseqModel(BaseFairseqModel):
     def __init__(self, args, model):
         super().__init__()
         assert isinstance(model, BaseFairseqModel)
-        if args.no_c10d:
-            self.ddp_model = parallel.DistributedDataParallel(
+        if args.ddp_backend == 'c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.DistributedDataParallel
+            elif c10d_status.has_c10d:
+                ddp_class = parallel._DistributedDataParallelC10d
+            else:
+                raise Exception(
+                    'Can\'t find c10d version of DistributedDataParallel. '
+                    'Please update PyTorch.'
+                )
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
+                bucket_cap_mb=args.bucket_cap_mb,
             )
-        else:
-            self.ddp_model = parallel._DistributedDataParallelC10d(
+        elif args.ddp_backend == 'no_c10d':
+            if c10d_status.is_default:
+                ddp_class = parallel.deprecated.DistributedDataParallel
+            else:
+                ddp_class = parallel.DistributedDataParallel
+            self.ddp_model = ddp_class(
                 module=model,
                 device_ids=[args.device_id],
                 output_device=args.device_id,
                 broadcast_buffers=False,
-                bucket_cap_mb=args.c10d_bucket_cap_mb,
             )
+        else:
+            raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

     def __call__(self, *args, **kwargs):
         return self.ddp_model(*args, **kwargs)
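To see the wrapper-selection logic at a glance, here is the same decision rewritten as a small helper function. select_ddp_class is purely illustrative (fairseq keeps this logic inline in DistributedFairseqModel.__init__); it assumes a c10d_status value shaped like the namedtuple sketched earlier.

from torch.nn import parallel

def select_ddp_class(ddp_backend, c10d_status):
    # Illustrative re-statement of the branches added above.
    if ddp_backend == 'c10d':
        if c10d_status.is_default:
            return parallel.DistributedDataParallel          # c10d is already the default class
        elif c10d_status.has_c10d:
            return parallel._DistributedDataParallelC10d     # private transitional class
        raise Exception('Can\'t find c10d version of DistributedDataParallel. '
                        'Please update PyTorch.')
    elif ddp_backend == 'no_c10d':
        if c10d_status.is_default:
            return parallel.deprecated.DistributedDataParallel  # legacy wrapper moved here
        return parallel.DistributedDataParallel                 # legacy wrapper is still the default
    raise ValueError('Unknown --ddp-backend: ' + ddp_backend)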
fairseq/options.py

@@ -185,10 +185,11 @@ def add_distributed_training_args(parser):
                        help='port number (not required if using --distributed-init-method)')
     group.add_argument('--device-id', default=0, type=int,
                        help='which GPU to use (usually configured automatically)')
-    group.add_argument('--no-c10d', action='store_true',
-                       help='don\'t use c10d distributed backend')
-    group.add_argument('--c10d-bucket-cap-mb', default=150, type=int, metavar='MB',
-                       help='bucket size for c10d backend')
+    group.add_argument('--ddp-backend', default='c10d', type=str,
+                       choices=['c10d', 'no_c10d'],
+                       help='DistributedDataParallel backend')
+    group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
+                       help='bucket size for reduction')
     return group
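The user-facing effect of this change is that the boolean --no-c10d flag becomes a --ddp-backend choice. A self-contained sketch with plain argparse (not fairseq's actual parser setup) shows how the new options behave:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ddp-backend', default='c10d', type=str,
                    choices=['c10d', 'no_c10d'],
                    help='DistributedDataParallel backend')
parser.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
                    help='bucket size for reduction')

args = parser.parse_args(['--ddp-backend', 'no_c10d'])
assert args.ddp_backend == 'no_c10d'
assert args.bucket_cap_mb == 150

# An unrecognized value now fails fast at parse time:
# parser.parse_args(['--ddp-backend', 'foo'])  -> SystemExit with a choices error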
fairseq/trainer.py

@@ -265,7 +265,7 @@ class Trainer(object):
         return logging_output

-    def valid_step(self, sample):
+    def valid_step(self, sample, raise_oom=False):
         """Do forward pass in evaluation mode."""
         with torch.no_grad():
             self.model.eval()

@@ -277,9 +277,20 @@ class Trainer(object):
             else:
                 ignore_results = False

-            _loss, sample_size, logging_output = self.task.get_loss(
-                self.model, self.criterion, sample,
-            )
+            try:
+                _loss, sample_size, logging_output = self.task.get_loss(
+                    self.model, self.criterion, sample,
+                )
+            except RuntimeError as e:
+                if 'out of memory' in str(e) and not raise_oom:
+                    print('| WARNING: ran out of memory, retrying batch')
+                    for p in self.model.parameters():
+                        if p.grad is not None:
+                            del p.grad  # free some memory
+                    torch.cuda.empty_cache()
+                    return self.valid_step(sample, raise_oom=True)
+                else:
+                    raise e

             if ignore_results:
                 logging_output, sample_size = {}, 0
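The valid_step change wraps the forward pass in a retry-on-OOM pattern: free gradient buffers, empty the CUDA cache, and rerun the batch exactly once. A generic sketch of the same pattern with placeholder names (compute_loss and model below are stand-ins, not fairseq APIs):

import torch

def valid_step_sketch(model, compute_loss, sample, raise_oom=False):
    with torch.no_grad():
        try:
            return compute_loss(model, sample)
        except RuntimeError as e:
            # CUDA OOM surfaces as a RuntimeError whose message contains 'out of memory'.
            if 'out of memory' in str(e) and not raise_oom:
                print('| WARNING: ran out of memory, retrying batch')
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                # Retry once; raise_oom=True prevents an infinite retry loop.
                return valid_step_sketch(model, compute_loss, sample, raise_oom=True)
            raise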
tests/test_binaries.py

@@ -292,7 +292,7 @@ def train_language_model(data_dir, arch):
             '--max-epoch', '1',
             '--no-progress-bar',
             '--distributed-world-size', '1',
-            '--no-c10d',
+            '--ddp-backend', 'no_c10d',
         ],
     )
     train.main(train_args)