OpenDAS / Fairseq · Commits

Commit 9f7c3ec6, authored Dec 21, 2017 by Myle Ott
Parent: cc7705d3

    Add support for sharded generation

Showing 3 changed files with 28 additions and 1 deletion:

    fairseq/data.py                      +17  -0
    fairseq/multiprocessing_trainer.py    +3  -1
    generate.py                           +8  -0
fairseq/data.py  (+17 -0)

@@ -175,6 +175,23 @@ def skip_group_enumerator(it, ngpus, offset=0):
         yield (idx, res)


+class sharded_iterator(object):
+
+    def __init__(self, itr, num_shards, shard_id):
+        assert shard_id >= 0 and shard_id < num_shards
+        self.itr = itr
+        self.num_shards = num_shards
+        self.shard_id = shard_id
+
+    def __len__(self):
+        return len(self.itr)
+
+    def __iter__(self):
+        for i, v in enumerate(self.itr):
+            if i % self.num_shards == self.shard_id:
+                yield v
+
+
 class LanguagePairDataset(object):

     # padding constants
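
For reference, a standalone sketch of what the new sharded_iterator does: it deals items of the wrapped iterator out round-robin, so shard k of N yields every N-th element starting at index k, and the N shards together cover the underlying iterator exactly once. The class body below is copied from the diff above; the small driver around it is illustrative only and is not part of the commit.

# Standalone demo of the round-robin sharding used by sharded_iterator.
# The class is copied from the diff above; the driver below is only an
# illustration of its behaviour.

class sharded_iterator(object):

    def __init__(self, itr, num_shards, shard_id):
        assert shard_id >= 0 and shard_id < num_shards
        self.itr = itr
        self.num_shards = num_shards
        self.shard_id = shard_id

    def __len__(self):
        return len(self.itr)

    def __iter__(self):
        for i, v in enumerate(self.itr):
            # keep every num_shards-th item, offset by shard_id
            if i % self.num_shards == self.shard_id:
                yield v


if __name__ == '__main__':
    items = list(range(10))
    print(list(sharded_iterator(items, num_shards=3, shard_id=0)))  # [0, 3, 6, 9]
    print(list(sharded_iterator(items, num_shards=3, shard_id=1)))  # [1, 4, 7]
    print(list(sharded_iterator(items, num_shards=3, shard_id=2)))  # [2, 5, 8]
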
fairseq/multiprocessing_trainer.py  (+3 -1)

@@ -15,7 +15,7 @@ import math
 import torch
 from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
-from fairseq import meters, nccl, utils
+from fairseq import nccl, utils
 from fairseq.multiprocessing_event_loop import MultiprocessingEventLoop, Future
 from fairseq.nag import NAG

@@ -116,6 +116,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
     def _build_lr_scheduler(self):
         if len(self.args.lr) > 1 or self.args.force_anneal > 0:
             lrs = self.args.lr
             def anneal(e):
                 if e < self.args.force_anneal:
                     # use fixed LR schedule

@@ -123,6 +124,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
                 else:
                     next_lr = lrs[-1] * self.args.lrshrink ** (e + 1 - self.args.force_anneal)
                 return next_lr / lrs[0]  # correct for scaling from LambdaLR
             lr_scheduler = LambdaLR(self.optimizer, anneal)
             lr_scheduler.best = None
         else:
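
As context for the LambdaLR hunks above: LambdaLR multiplies the optimizer's base learning rate by the factor returned from anneal(e), which is why the code divides next_lr by lrs[0]. Below is a hedged sketch of the annealing branch only (the fixed-schedule branch for e < force_anneal is collapsed in the diff, so it is not reproduced); the variable names stand in for the corresponding self.args fields and the numbers are made up for illustration.

# Illustrative only: the post-anneal branch of the schedule from the hunk above.
# lrs, lrshrink and force_anneal stand in for self.args.lr, self.args.lrshrink
# and self.args.force_anneal; the values are made-up examples.

lrs = [0.25]          # --lr
lrshrink = 0.1        # --lrshrink
force_anneal = 5      # --force-anneal

def anneal_factor(e):
    # learning rate for epoch e once annealing has started (e >= force_anneal)
    next_lr = lrs[-1] * lrshrink ** (e + 1 - force_anneal)
    # LambdaLR scales the base LR by the returned factor, so divide by lrs[0]
    return next_lr / lrs[0]

for e in range(force_anneal, force_anneal + 3):
    print(e, anneal_factor(e), lrs[0] * anneal_factor(e))
# epoch 5 -> factor 0.1,   effective LR 0.025
# epoch 6 -> factor ~0.01,  effective LR ~0.0025
# epoch 7 -> factor ~0.001, effective LR ~0.00025
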
generate.py  (+8 -0)

@@ -23,6 +23,10 @@ def main():
                               help='batch size')
     dataset_args.add_argument('--gen-subset', default='test', metavar='SPLIT',
                               help='data subset to generate (train, valid, test)')
+    dataset_args.add_argument('--num-shards', default=1, type=int, metavar='N',
+                              help='shard generation over N shards')
+    dataset_args.add_argument('--shard-id', default=0, type=int, metavar='ID',
+                              help='id of the shard to generate (id < num_shards)')
     options.add_generation_args(parser)
     args = parser.parse_args()

@@ -72,6 +76,10 @@ def main():
     itr = dataset.eval_dataloader(
         args.gen_subset, max_sentences=args.batch_size, max_positions=max_positions,
         skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
+    if args.num_shards > 1:
+        if args.shard_id < 0 or args.shard_id >= args.num_shards:
+            raise ValueError('--shard-id must be between 0 and num_shards')
+        itr = data.sharded_iterator(itr, args.num_shards, args.shard_id)
     num_sentences = 0
     with utils.build_progress_bar(args, itr) as t:
         wps_meter = TimeMeter()
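
Putting the two new flags together, generation can be split across processes by giving each one the same --num-shards and a distinct --shard-id; each process then decodes a disjoint, interleaved slice of the batches. A hedged sketch of such a launcher follows: only --gen-subset, --num-shards and --shard-id come from this commit, while DATA_DIR, CHECKPOINT and the --path flag are placeholders that may differ in a real setup.

# Hypothetical launcher for sharded generation. Only --gen-subset, --num-shards
# and --shard-id are taken from this commit; DATA_DIR, CHECKPOINT and the
# --path flag are placeholders/assumptions.
import subprocess
import sys

DATA_DIR = 'data-bin/example'        # placeholder: preprocessed data directory
CHECKPOINT = 'checkpoints/model.pt'  # placeholder: trained model checkpoint
NUM_SHARDS = 4

procs = []
for shard_id in range(NUM_SHARDS):
    cmd = [
        sys.executable, 'generate.py', DATA_DIR,
        '--path', CHECKPOINT,             # assumed flag, not shown in this diff
        '--gen-subset', 'test',
        '--num-shards', str(NUM_SHARDS),
        '--shard-id', str(shard_id),
    ]
    # one generation process per shard; each decodes every NUM_SHARDS-th batch
    procs.append(subprocess.Popen(cmd))

for p in procs:
    p.wait()
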