OpenDAS / Fairseq / Commits

Commit 2e507d3c, authored Aug 30, 2018 by Myle Ott

    Clean up FairseqTask so that it's easier to extend/add new tasks

Parent: 6296de82
Showing 12 changed files with 372 additions and 136 deletions:

    eval_lm.py                             +4    −2
    fairseq/data/data_utils.py             +137  −78
    fairseq/data/fairseq_dataset.py        +18   −6
    fairseq/data/language_pair_dataset.py  +63   −21
    fairseq/data/monolingual_dataset.py    +41   −11
    fairseq/tasks/fairseq_task.py          +63   −2
    fairseq/tasks/translation.py           +3    −0
    fairseq/utils.py                       +15   −1
    generate.py                            +5    −2
    interactive.py                         +10   −6
    tests/test_train.py                    +1    −1
    train.py                               +12   −6
eval_lm.py  (+4 −2)

@@ -59,11 +59,13 @@ def main(parsed_args):
     assert len(models) > 0

-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens or 36000,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(*[
+            model.max_positions() for model in models
+        ]),
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         ignore_invalid_inputs=True,
fairseq/data/data_utils.py  (+137 −78)

@@ -12,8 +12,6 @@ import os
 import numpy as np
 import torch

-from . import FairseqDataset
-

 def infer_language_pair(path):
     """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""

@@ -99,42 +97,35 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
 class EpochBatchIterator(object):
-    """Iterate over a FairseqDataset and yield batches bucketed by size.
+    """A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.

-    Batches may contain sequences of different lengths. This iterator can be
-    reused across multiple epochs with the next_epoch_itr() method.
+    Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
+
+    - can be reused across multiple epochs with the :func:`next_epoch_itr`
+      method (optionally shuffled between epochs)
+    - can be serialized/deserialized with the :func:`state_dict` and
+      :func:`load_state_dict` methods
+    - supports sharding with the ``num_shards`` and ``shard_id`` arguments

     Args:
-        dataset: a FairseqDataset
-        max_tokens: max number of tokens in each batch
-        max_sentences: max number of sentences in each batch
-        max_positions: max sentence length supported by the model
-        ignore_invalid_inputs: don't raise Exception for sentences that are too long
-        required_batch_size_multiple: require batch size to be a multiple of N
-        seed: seed for random number generator for reproducibility
-        num_shards: shard the data iterator into N shards
-        shard_id: which shard of the data iterator to return
+        dataset (Dataset): dataset from which to load the data
+        batch_sampler (Sampler): an iterator over batches of indices
+        seed (int, optional): seed for random number generator for
+            reproducibility. Default: ``1``
+        num_shards (int, optional): shard the data iterator into N
+            shards. Default: ``1``
+        shard_id (int, optional): which shard of the data iterator to
+            return. Default: ``0``
     """

-    def __init__(
-        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
-        ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
-        num_shards=1, shard_id=0,
-    ):
-        assert isinstance(dataset, FairseqDataset)
+    def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
+        assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
-        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
-        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
-        self.max_positions = max_positions
-        self.ignore_invalid_inputs = ignore_invalid_inputs
-        self.bsz_mult = required_batch_size_multiple
+        self.frozen_batches = tuple(batch_sampler)
         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id

-        with numpy_seed(self.seed):
-            self.frozen_batches = tuple(self._batch_generator())
-
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None

@@ -143,7 +134,13 @@ class EpochBatchIterator(object):
         return len(self.frozen_batches)

     def next_epoch_itr(self, shuffle=True):
-        """Shuffle batches and return a new iterator over the dataset."""
+        """Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator. Default: ``True``
+        """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr
             self._next_epoch_itr = None

@@ -153,10 +150,12 @@ class EpochBatchIterator(object):
         return self._cur_epoch_itr

     def end_of_epoch(self):
+        """Returns whether the most recent epoch iterator has been exhausted"""
         return not self._cur_epoch_itr.has_next()

     @property
     def iterations_in_epoch(self):
+        """The number of consumed batches in the current epoch."""
         if self._cur_epoch_itr is not None:
             return self._cur_epoch_itr.count
         elif self._next_epoch_itr is not None:

@@ -193,55 +192,6 @@ class EpochBatchIterator(object):
             batch_sampler=ShardedIterator(
                 batches, self.num_shards, self.shard_id, fill_value=[]),
         ))

-    def _batch_generator(self):
-        batch = []
-
-        def is_batch_full(num_tokens):
-            if len(batch) == 0:
-                return False
-            if len(batch) == self.max_sentences:
-                return True
-            if num_tokens > self.max_tokens:
-                return True
-            return False
-
-        sample_len = 0
-        sample_lens = []
-        ignored = []
-        for idx in self.dataset.ordered_indices():
-            if not self.dataset.valid_size(idx, self.max_positions):
-                if self.ignore_invalid_inputs:
-                    ignored.append(idx)
-                    continue
-                raise Exception((
-                    'Size of sample #{} is invalid, max_positions={}, skip this '
-                    'example with --skip-invalid-size-inputs-valid-test'
-                ).format(idx, self.max_positions))
-
-            sample_lens.append(self.dataset.num_tokens(idx))
-            sample_len = max(sample_len, sample_lens[-1])
-            num_tokens = (len(batch) + 1) * sample_len
-            if is_batch_full(num_tokens):
-                mod_len = max(
-                    self.bsz_mult * (len(batch) // self.bsz_mult),
-                    len(batch) % self.bsz_mult,
-                )
-                yield batch[:mod_len]
-                batch = batch[mod_len:]
-                sample_lens = sample_lens[mod_len:]
-                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print((
-                '| WARNING: {} samples have invalid sizes and will be skipped, '
-                'max_positions={}, first few sample ids={}'
-            ).format(len(ignored), self.max_positions, ignored[:10]))


 @contextlib.contextmanager
 def numpy_seed(seed):

@@ -256,3 +206,112 @@ def numpy_seed(seed):
         yield
     finally:
         np.random.set_state(state)
+
+
+def collect_filtered(function, iterable, filtered):
+    """
+    Similar to :func:`filter` but collects filtered elements in ``filtered``.
+
+    Args:
+        function (callable): function that returns ``False`` for elements that
+            should be filtered
+        iterable (iterable): iterable to filter
+        filtered (list): list to store filtered elements
+    """
+    for el in iterable:
+        if function(el):
+            yield el
+        else:
+            filtered.append(el)
+
+
+def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
+    """
+    Filter indices based on their size.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        size_fn (callable): function that returns the size of a given index
+        max_positions (tuple): filter elements larger than this size.
+            Comparisons are done component-wise.
+        raise_exception (bool, optional): if ``True``, raise an exception
+            if any elements are filtered. Default: ``False``
+    """
+    def check_size(idx):
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            return size_fn(idx) < max_positions
+        else:
+            return all(a <= b for a, b in zip(size_fn(idx), max_positions))
+
+    ignored = []
+    itr = collect_filtered(check_size, indices, ignored)
+
+    for idx in itr:
+        if len(ignored) > 0 and raise_exception:
+            raise Exception((
+                'Size of sample #{} is invalid (={}) since max_positions={}, '
+                'skip this example with --skip-invalid-size-inputs-valid-test'
+            ).format(idx, size_fn(idx), max_positions))
+        yield idx
+
+    if len(ignored) > 0:
+        print((
+            '| WARNING: {} samples have invalid sizes and will be skipped, '
+            'max_positions={}, first few sample ids={}'
+        ).format(len(ignored), max_positions, ignored[:10]))
+
+
+def batch_by_size(
+    indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+    required_batch_size_multiple=1,
+):
+    """
+    Yield mini-batches of indices bucketed by size. Batches may contain
+    sequences of different lengths.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        num_tokens_fn (callable): function that returns the number of tokens at
+            a given index
+        max_tokens (int, optional): max number of tokens in each batch.
+            Default: ``None``
+        max_sentences (int, optional): max number of sentences in each
+            batch. Default: ``None``
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N. Default: ``1``
+    """
+    max_tokens = max_tokens if max_tokens is not None else float('Inf')
+    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    bsz_mult = required_batch_size_multiple
+
+    batch = []
+
+    def is_batch_full(num_tokens):
+        if len(batch) == 0:
+            return False
+        if len(batch) == max_sentences:
+            return True
+        if num_tokens > max_tokens:
+            return True
+        return False
+
+    sample_len = 0
+    sample_lens = []
+    for idx in indices:
+        sample_lens.append(num_tokens_fn(idx))
+        sample_len = max(sample_len, sample_lens[-1])
+        num_tokens = (len(batch) + 1) * sample_len
+        if is_batch_full(num_tokens):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            yield batch[:mod_len]
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+
+        batch.append(idx)
+
+    if len(batch) > 0:
+        yield batch
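The batching logic that used to live inside `EpochBatchIterator._batch_generator` is now a standalone generator. As a minimal sketch of how `batch_by_size` behaves (the toy `lengths` list and the `max_tokens=10` cap are illustrative, not from the commit):

    from fairseq.data import data_utils

    # Toy sample lengths. batch_by_size estimates a batch's padded size as
    # (number of sentences) * (longest sample so far) and cuts the batch
    # before that estimate would exceed max_tokens.
    lengths = [3, 4, 5, 9, 2, 6]
    indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

    batches = list(data_utils.batch_by_size(
        indices,
        num_tokens_fn=lambda i: lengths[i],
        max_tokens=10,
    ))
    print(batches)  # [[4, 0], [1, 2], [5], [3]] -- each padded batch stays <= 10 tokens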
fairseq/data/fairseq_dataset.py  (+18 −6)

@@ -7,6 +7,8 @@
 import torch.utils.data

+from fairseq.data import data_utils
+

 class FairseqDataset(torch.utils.data.Dataset):
     """A dataset that provides helpers for batching."""

@@ -18,7 +20,14 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[int]): sample indices to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         raise NotImplementedError

     def get_dummy_batch(self, num_tokens, max_positions):

@@ -26,13 +35,16 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
         raise NotImplementedError

-    def ordered_indices(self):
-        """Ordered indices for batching."""
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         raise NotImplementedError

-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         raise NotImplementedError
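With `size()` replacing `valid_size()`, a custom dataset only needs a handful of methods to participate in the new batching pipeline. A minimal sketch of the interface (the `ToyDataset` class is hypothetical, not part of fairseq):

    import numpy as np
    import torch

    from fairseq.data import FairseqDataset


    class ToyDataset(FairseqDataset):
        """Wraps a list of 1D LongTensors."""

        def __init__(self, tensors):
            self.tensors = tensors
            self.sizes = np.array([t.numel() for t in tensors])

        def __getitem__(self, index):
            return self.tensors[index]

        def __len__(self):
            return len(self.tensors)

        def num_tokens(self, index):
            return self.sizes[index]  # enforces --max-tokens during batching

        def size(self, index):
            return self.sizes[index]  # compared against --max-positions when filtering

        def ordered_indices(self):
            return np.argsort(self.sizes)  # group similar lengths into batches

        def collater(self, samples):
            # naive zero-padding on the right; real datasets pad with vocab.pad()
            out = torch.zeros(len(samples), max(s.numel() for s in samples), dtype=torch.long)
            for i, s in enumerate(samples):
                out[i, :s.numel()] = s
            return out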
fairseq/data/language_pair_dataset.py  (+63 −21)

@@ -8,6 +8,8 @@
 import numpy as np
 import torch

+from fairseq import utils
+
 from . import data_utils, FairseqDataset

@@ -59,7 +61,27 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
 class LanguagePairDataset(FairseqDataset):
-    """A pair of torch.utils.data.Datasets."""
+    """
+    A pair of torch.utils.data.Datasets.
+
+    Args:
+        src (torch.utils.data.Dataset): source dataset to wrap
+        src_sizes (List[int]): source sentence lengths
+        src_dict (fairseq.data.Dictionary): source vocabulary
+        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
+        tgt_sizes (List[int], optional): target sentence lengths
+        tgt_dict (fairseq.data.Dictionary, optional): target vocabulary
+        left_pad_source (bool, optional): pad source tensors on the left side.
+            Default: ``True``
+        left_pad_target (bool, optional): pad target tensors on the left side.
+            Default: ``False``
+        max_source_positions (int, optional): max number of tokens in the source
+            sentence. Default: ``1024``
+        max_target_positions (int, optional): max number of tokens in the target
+            sentence. Default: ``1024``
+        shuffle (bool, optional): shuffle dataset elements before batching.
+            Default: ``True``
+    """

     def __init__(
         self, src, src_sizes, src_dict,

@@ -95,15 +117,43 @@ class LanguagePairDataset(FairseqDataset):
         return len(self.src)

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the left if ``left_pad_source`` is True.
+          - `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
+            of each source sentence of shape `(bsz)`
+          - `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
+            tokens in the target sentence, shifted right by one position for
+            input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
+            will appear on the left if ``left_pad_target`` is True.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          left if ``left_pad_target`` is True.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
         )

     def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
+        """Return a dummy batch with a given number of tokens."""
+        src_len, tgt_len = utils.resolve_max_positions(
+            (src_len, tgt_len),
+            max_positions,
+            (self.max_source_positions, self.max_target_positions),
+        )
         bsz = num_tokens // max(src_len, tgt_len)
         return self.collater([
             {

@@ -115,11 +165,18 @@ class LanguagePairDataset(FairseqDataset):
         ])

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
         return max(
             self.src_sizes[index],
             self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
         )

+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (
+            self.src_sizes[index],
+            self.tgt_sizes[index] if self.tgt_sizes is not None else 0,
+        )
+
     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             indices = np.random.permutation(len(self))
         else:

@@ -127,18 +184,3 @@ class LanguagePairDataset(FairseqDataset):
         if self.tgt_sizes is not None:
             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        return (
-            self.src_sizes[index] <= max_source_positions
-            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
-        )
-
-    def _get_max_positions(self, max_positions):
-        if max_positions is None:
-            return self.max_source_positions, self.max_target_positions
-        assert len(max_positions) == 2
-        max_src_pos, max_tgt_pos = max_positions
-        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
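The new `collater` docstring documents `prev_output_tokens`, the teacher-forcing input to the decoder. For a single sentence the shift it describes amounts to rotating the EOS token from the end of `target` to the front, roughly as in this sketch (token IDs are made up):

    import torch

    eos = 2
    target = torch.tensor([4, 5, 6, eos])

    # shifted right by one position for input feeding/teacher forcing
    prev_output_tokens = torch.cat([target[-1:], target[:-1]])
    print(prev_output_tokens.tolist())  # [2, 4, 5, 6]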
fairseq/data/monolingual_dataset.py  (+41 −11)

@@ -31,7 +31,16 @@ def collate(samples, pad_idx, eos_idx):
 class MonolingualDataset(FairseqDataset):
-    """A wrapper around torch.utils.data.Dataset for monolingual data."""
+    """
+    A wrapper around torch.utils.data.Dataset for monolingual data.
+
+    Args:
+        dataset (torch.utils.data.Dataset): dataset to wrap
+        sizes (List[int]): sentence lengths
+        vocab (fairseq.data.Dictionary): vocabulary
+        shuffle (bool, optional): shuffle the elements before batching.
+            Default: ``True``
+    """

     def __init__(self, dataset, sizes, vocab, shuffle):
         self.dataset = dataset

@@ -47,12 +56,31 @@ class MonolingualDataset(FairseqDataset):
         return len(self.dataset)

     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the right.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          right.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(samples, self.vocab.pad(), self.vocab.eos())

     def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        tgt_len = min(tgt_len, max_positions)
+        """Return a dummy batch with a given number of tokens."""
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            tgt_len = min(tgt_len, max_positions)
         bsz = num_tokens // tgt_len
         target = self.vocab.dummy_sentence(tgt_len + 1)
         source, target = target[:-1], target[1:]

@@ -62,19 +90,21 @@ class MonolingualDataset(FairseqDataset):
         ])

     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         return self.sizes[index]

     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             order = [np.random.permutation(len(self))]
         else:
             order = [np.arange(len(self))]
         order.append(np.flip(self.sizes, 0))
         return np.lexsort(order)

-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        return self.sizes[index] <= max_positions
fairseq/tasks/fairseq_task.py  (+63 −2)

@@ -5,11 +5,13 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+from fairseq.data import data_utils, FairseqDataset
+

 class FairseqTask(object):
     """
-    A Task defines the data format, stores shared state (e.g., dictionaries) and
-    provides helpers for building the model/criterion and calculating the loss.
+    Tasks store dictionaries and provide helpers for loading/iterating over
+    Datasets, initializing the Model/Criterion and calculating the loss.
     """

     @staticmethod

@@ -37,6 +39,62 @@ class FairseqTask(object):
             raise TypeError('Datasets are expected to be of type FairseqDataset')
         return self.datasets[split]

+    def get_batch_iterator(
+        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
+        ignore_invalid_inputs=False, required_batch_size_multiple=1,
+        seed=1, num_shards=1, shard_id=0,
+    ):
+        """
+        Generate batches of indices.
+
+        Args:
+            dataset (FairseqDataset): dataset to batch
+            max_tokens (int, optional): max number of tokens in each batch.
+                Default: ``None``
+            max_sentences (int, optional): max number of sentences in each
+                batch. Default: ``None``
+            max_positions (optional): max sentence length supported by the
+                model. Default: ``None``
+            ignore_invalid_inputs (bool, optional): don't raise Exception for
+                sentences that are too long. Default: ``False``
+            required_batch_size_multiple (int, optional): require batch size to
+                be a multiple of N. Default: ``1``
+            seed (int, optional): seed for random number generator for
+                reproducibility. Default: ``1``
+            num_shards (int, optional): shard the data iterator into N
+                shards. Default: ``1``
+            shard_id (int, optional): which shard of the data iterator to
+                return. Default: ``0``
+
+        Returns:
+            EpochBatchIterator: a batched iterator over the given dataset split
+        """
+        assert isinstance(dataset, FairseqDataset)
+
+        # get indices ordered by example size
+        with data_utils.numpy_seed(seed):
+            indices = dataset.ordered_indices()
+
+        # filter examples that are too large
+        indices = data_utils.filter_by_size(
+            indices, dataset.size, max_positions,
+            raise_exception=(not ignore_invalid_inputs),
+        )
+
+        # create mini-batches with given size constraints
+        batch_sampler = data_utils.batch_by_size(
+            indices, dataset.num_tokens, max_tokens=max_tokens,
+            max_sentences=max_sentences,
+            required_batch_size_multiple=required_batch_size_multiple,
+        )
+
+        # return a reusable, sharded iterator
+        return data_utils.EpochBatchIterator(
+            dataset=dataset, batch_sampler=batch_sampler, seed=seed,
+            num_shards=num_shards, shard_id=shard_id,
+        )
+
     def build_model(self, args):
         from fairseq import models
         return models.build_model(args, self)

@@ -48,6 +106,9 @@ class FairseqTask(object):
     def get_loss(self, model, criterion, sample):
         return criterion(model, sample)

+    def max_positions(self):
+        return None
+
     @property
     def source_dictionary(self):
         raise NotImplementedError
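`get_batch_iterator` is the single entry point that the driver scripts below switch to: order by size, filter by `--max-positions`, bucket by `--max-tokens`, then wrap in a reusable `EpochBatchIterator`. A condensed usage sketch, assuming a task and model set up as in train.py (argument values are illustrative):

    from fairseq import tasks, utils

    task = tasks.setup_task(args)        # e.g. a TranslationTask
    task.load_dataset('train')
    model = task.build_model(args)

    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset('train'),
        max_tokens=4000,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        seed=1,
    )

    # the returned EpochBatchIterator is reusable across epochs
    for epoch in range(3):
        for sample in epoch_itr.next_epoch_itr(shuffle=True):
            pass  # forward/backward on the collated mini-batch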
fairseq/tasks/translation.py  (+3 −0)

@@ -139,6 +139,9 @@ class TranslationTask(FairseqTask):
             max_target_positions=self.args.max_target_positions,
         )

+    def max_positions(self):
+        return (self.args.max_source_positions, self.args.max_target_positions)
+
     @property
     def source_dictionary(self):
         return self.src_dict
...
fairseq/utils.py
View file @
2e507d3c
...
...
@@ -150,7 +150,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
ensemble
=
[]
for
state
in
states
:
args
=
state
[
'args'
]
if
model_arg_overrides
is
not
None
:
args
=
_override_model_args
(
args
,
model_arg_overrides
)
...
...
@@ -399,3 +399,17 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
idx
=
int
(
m
.
group
(
1
))
if
len
(
m
.
groups
())
>
0
else
i
entries
.
append
((
idx
,
m
.
group
(
0
)))
return
[
os
.
path
.
join
(
path
,
x
[
1
])
for
x
in
sorted
(
entries
,
reverse
=
True
)]
def
resolve_max_positions
(
*
args
):
"""Resolve max position constraints from multiple sources."""
max_positions
=
None
for
arg
in
args
:
if
max_positions
is
None
:
max_positions
=
arg
elif
arg
is
not
None
:
if
isinstance
(
arg
,
float
)
or
isinstance
(
arg
,
int
):
max_positions
=
min
(
max_positions
,
arg
)
else
:
max_positions
=
tuple
(
map
(
min
,
zip
(
max_positions
,
arg
)))
return
max_positions
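`resolve_max_positions` folds position limits from the task and every model in an ensemble into a single constraint: `None` entries are skipped, scalars are reduced with `min`, and tuples are combined element-wise (mixing scalar and tuple arguments is not handled by this version). For example:

    from fairseq import utils

    print(utils.resolve_max_positions(None, 1024))                 # 1024
    print(utils.resolve_max_positions((1024, 1024), (512, 2048)))  # (512, 1024)
    print(utils.resolve_max_positions(256, 1024, 512))             # 256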
generate.py  (+5 −2)

@@ -54,11 +54,14 @@ def main(args):
     align_dict = utils.load_align_dict(args.replace_unk)

     # Load dataset (possibly sharded)
-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(
+            task.max_positions(),
+            *[model.max_positions() for model in models]
+        ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
         required_batch_size_multiple=8,
         num_shards=args.num_shards,
interactive.py  (+10 −6)

@@ -32,14 +32,14 @@ def buffered_read(buffer_size):
         yield buffer


-def make_batches(lines, args, src_dict, max_positions):
+def make_batches(lines, args, task, max_positions):
     tokens = [
-        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
+        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
         for src_str in lines
     ]
     lengths = np.array([t.numel() for t in tokens])
-    itr = data.EpochBatchIterator(
-        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
+    itr = task.get_batch_iterator(
+        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
         max_positions=max_positions,

@@ -76,7 +76,6 @@ def main(args):
     models, model_args = utils.load_ensemble_for_inference(
         model_paths, task, model_arg_overrides=eval(args.model_overrides))

     # Set dictionaries
-    src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary

     # Optimize ensemble for generation

@@ -151,13 +150,18 @@ def main(args):
         return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]

+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        *[model.max_positions() for model in models]
+    )
+
     if args.buffer_size > 1:
         print('| Sentence buffer size:', args.buffer_size)
     print('| Type the input sentence and press return:')
     for inputs in buffered_read(args.buffer_size):
         indices = []
         results = []
-        for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
+        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
             indices.extend(batch_indices)
             results += process_batch(batch)
tests/test_train.py  (+1 −1)

@@ -44,7 +44,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
     epoch_itr = data.EpochBatchIterator(
         dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
-        max_tokens=1,
+        batch_sampler=[[i] for i in range(epoch_size)],
     )
     return trainer, epoch_itr
...
train.py
View file @
2e507d3c
...
...
@@ -12,7 +12,7 @@ import os
import
math
import
torch
from
fairseq
import
data
,
distributed_utils
,
options
,
progress_bar
,
tasks
,
utils
from
fairseq
import
distributed_utils
,
options
,
progress_bar
,
tasks
,
utils
from
fairseq.fp16_trainer
import
FP16Trainer
from
fairseq.trainer
import
Trainer
from
fairseq.meters
import
AverageMeter
,
StopwatchMeter
...
...
@@ -57,11 +57,14 @@ def main(args):
))
# Initialize dataloader
max_positions
=
trainer
.
get_model
().
max_positions
()
epoch_itr
=
data
.
EpochBatchIterator
(
max_positions
=
utils
.
resolve_max_positions
(
task
.
max_positions
(),
trainer
.
get_model
().
max_positions
(),
)
epoch_itr
=
task
.
get_batch_iterator
(
dataset
=
task
.
dataset
(
args
.
train_subset
),
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences
_valid
,
max_sentences
=
args
.
max_sentences
,
max_positions
=
max_positions
,
ignore_invalid_inputs
=
True
,
required_batch_size_multiple
=
8
,
...
...
@@ -193,11 +196,14 @@ def validate(args, trainer, task, epoch_itr, subsets):
valid_losses
=
[]
for
subset
in
subsets
:
# Initialize data iterator
itr
=
data
.
EpochB
atch
I
terator
(
itr
=
task
.
get_b
atch
_i
terator
(
dataset
=
task
.
dataset
(
subset
),
max_tokens
=
args
.
max_tokens
,
max_sentences
=
args
.
max_sentences_valid
,
max_positions
=
trainer
.
get_model
().
max_positions
(),
max_positions
=
utils
.
resolve_max_positions
(
task
.
max_positions
(),
trainer
.
get_model
().
max_positions
(),
),
ignore_invalid_inputs
=
args
.
skip_invalid_size_inputs_valid_test
,
required_batch_size_multiple
=
8
,
seed
=
args
.
seed
,
...
...