OpenDAS / Fairseq / Commits

Commit 0e101e9c
Authored Sep 02, 2018 by Myle Ott
Parent d473620e

    Misc changes to simplify upcoming tutorial

Showing 7 changed files with 65 additions and 27 deletions
eval_lm.py                             +1   -1
fairseq/data/data_utils.py             +2   -1
fairseq/data/language_pair_dataset.py  +30  -17
fairseq/data/monolingual_dataset.py    +5   -2
fairseq/models/fairseq_model.py        +13  -4
fairseq/tasks/fairseq_task.py          +1   -1
fairseq/utils.py                       +13  -1
eval_lm.py

@@ -109,7 +109,7 @@ def main(parsed_args):
                 print('| Skipping tokens with inf scores:',
                       task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                 pos_scores = pos_scores[(~inf_scores).nonzero()]
-            score_sum += pos_scores.sum()
+            score_sum += utils.item(pos_scores.sum())
             count += pos_scores.numel() - skipped_toks

             if args.output_word_probs or args.output_word_stats:
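Wrapping the reduction in `utils.item` keeps `score_sum` a plain Python float instead of a 0-dim tensor, so the running total does not accumulate device-resident tensor state across the evaluation loop. A minimal sketch of the idea (the exact body of `utils.item` in this revision may differ):

    import torch

    def item(tensor):
        # Convert a 0-dim tensor (or anything exposing .item()) to a Python
        # number; pass plain numbers through unchanged.
        if hasattr(tensor, 'item'):
            return tensor.item()
        return tensor

    score_sum = 0.
    pos_scores = torch.tensor([-1.5, -0.25, -3.0])
    score_sum += item(pos_scores.sum())  # score_sum stays a float, not a tensor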
fairseq/data/data_utils.py

@@ -88,7 +88,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             return size_fn(idx) <= max_positions
         else:
-            return all(a <= b for a, b in zip(size_fn(idx), max_positions))
+            return all(a is None or b is None or a <= b
+                       for a, b in zip(size_fn(idx), max_positions))

     ignored = []
     itr = collect_filtered(check_size, indices, ignored)
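With this change, a ``None`` entry on either side of the comparison means "no constraint" for that dimension, so an example is only filtered when both the size and the limit are concrete numbers. A small illustration (the sizes here are hypothetical, not from the repo):

    def check_size(size, max_positions):
        # Mirrors the updated comparison: None on either side disables the check.
        return all(
            a is None or b is None or a <= b
            for a, b in zip(size, max_positions)
        )

    print(check_size((512, 900), (1024, None)))   # True: target side unconstrained
    print(check_size((2048, 30), (1024, None)))   # False: source side too long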
fairseq/data/language_pair_dataset.py

@@ -13,7 +13,10 @@ from fairseq import utils
 from . import data_utils, FairseqDataset


-def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
+def collate(
+    samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
+    input_feeding=True,
+):
     if len(samples) == 0:
         return {}

@@ -35,6 +38,10 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
     target = None
     if samples[0].get('target', None) is not None:
         target = merge('target', left_pad=left_pad_target)
+        target = target.index_select(0, sort_order)
+        ntokens = sum(len(s['target']) for s in samples)
+
+        if input_feeding:
             # we create a shifted version of targets for feeding the
             # previous output token(s) into the next decoder step
             prev_output_tokens = merge(

@@ -43,21 +50,21 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
                 move_eos_to_beginning=True,
             )
             prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
-        target = target.index_select(0, sort_order)
-        ntokens = sum(len(s['target']) for s in samples)
     else:
         ntokens = sum(len(s['source']) for s in samples)

-    return {
+    batch = {
         'id': id,
         'ntokens': ntokens,
         'net_input': {
             'src_tokens': src_tokens,
             'src_lengths': src_lengths,
-            'prev_output_tokens': prev_output_tokens,
         },
         'target': target,
     }
+    if prev_output_tokens is not None:
+        batch['net_input']['prev_output_tokens'] = prev_output_tokens
+    return batch


 class LanguagePairDataset(FairseqDataset):

@@ -81,6 +88,9 @@ class LanguagePairDataset(FairseqDataset):
             sentence. Default: ``1024``
         shuffle (bool, optional): shuffle dataset elements before batching.
             Default: ``True``
+        input_feeding (bool, optional): create a shifted version of the targets
+            to be passed into the model for input feeding/teacher forcing.
+            Default: ``True``
     """

     def __init__(

@@ -88,7 +98,7 @@ class LanguagePairDataset(FairseqDataset):
         tgt=None, tgt_sizes=None, tgt_dict=None,
         left_pad_source=True, left_pad_target=False,
         max_source_positions=1024, max_target_positions=1024,
-        shuffle=True,
+        shuffle=True, input_feeding=True,
     ):
         if tgt_dict is not None:
             assert src_dict.pad() == tgt_dict.pad()

@@ -105,6 +115,7 @@ class LanguagePairDataset(FairseqDataset):
         self.max_source_positions = max_source_positions
         self.max_target_positions = max_target_positions
         self.shuffle = shuffle
+        self.input_feeding = input_feeding

     def __getitem__(self, index):
         return {

@@ -119,22 +130,23 @@ class LanguagePairDataset(FairseqDataset):
     def collater(self, samples):
         """Merge a list of samples to form a mini-batch.

-        Returned mini-batches contain the following keys:
+        Returns mini-batches with the following keys:

         - `id` (torch.LongTensor): example IDs in the original input order
         - `ntokens` (int): total number of tokens in the batch
         - `net_input` (dict): the input to the Model, containing keys:
           - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
             the source sentence of shape `(bsz, src_len)`. Padding will appear
-            on the left if ``left_pad_source`` is True.
+            on the left if *left_pad_source* is True.
           - `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
             of each source sentence of shape `(bsz)`
           - `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
             tokens in the target sentence, shifted right by one position for
-            input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
-            will appear on the left if ``left_pad_target`` is True.
+            input feeding/teacher forcing, of shape `(bsz, tgt_len)`. This key
+            will only be present if *input_feeding* is ``True``. Padding will
+            appear on the left if *left_pad_target* is ``True``.
         - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
           target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
-          left if ``left_pad_target`` is True.
+          left if *left_pad_target* is ``True``.

         Args:
             samples (List[dict]): samples to collate

@@ -145,6 +157,7 @@ class LanguagePairDataset(FairseqDataset):
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
+            input_feeding=self.input_feeding,
         )

     def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
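The new *input_feeding* flag controls whether the collater emits `prev_output_tokens`, i.e. the target shifted right by one position (with the trailing EOS moved to the front via `move_eos_to_beginning`) for teacher forcing. A hedged sketch of the shift itself, assuming a toy sentence whose last token is eos=2 (illustrative values, not read from a real dictionary):

    import torch

    def shift_for_input_feeding(target):
        # Move the trailing EOS to the front:
        # [w1, w2, w3, eos] -> [eos, w1, w2, w3]
        return torch.cat([target[-1:], target[:-1]])

    target = torch.tensor([4, 5, 6, 2])        # sentence ending in eos=2
    print(shift_for_input_feeding(target))     # tensor([2, 4, 5, 6])

With `input_feeding=False` (e.g. for models that do not consume decoder history), `batch['net_input']` contains only `src_tokens` and `src_lengths`, which is exactly what the new `if prev_output_tokens is not None` guard arranges.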
fairseq/data/monolingual_dataset.py

@@ -25,6 +25,9 @@ def collate(samples, pad_idx, eos_idx):
         'ntokens': sum(len(s['target']) for s in samples),
         'net_input': {
             'src_tokens': merge('source'),
+            'src_lengths': torch.LongTensor([
+                s['source'].numel() for s in samples
+            ]),
         },
         'target': merge('target'),
     }

@@ -42,7 +45,7 @@ class MonolingualDataset(FairseqDataset):
             Default: ``True``
     """

-    def __init__(self, dataset, sizes, vocab, shuffle):
+    def __init__(self, dataset, sizes, vocab, shuffle=True):
         self.dataset = dataset
         self.sizes = np.array(sizes)
         self.vocab = vocab

@@ -58,7 +61,7 @@ class MonolingualDataset(FairseqDataset):
     def collater(self, samples):
         """Merge a list of samples to form a mini-batch.

-        Returned mini-batches contain the following keys:
+        Returns mini-batches with the following keys:

         - `id` (torch.LongTensor): example IDs in the original input order
         - `ntokens` (int): total number of tokens in the batch
         - `net_input` (dict): the input to the Model, containing keys:
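The monolingual collater now also reports `src_lengths`, mirroring the translation collater, so code consuming `net_input` can treat both dataset types uniformly. A sketch of the added computation over a toy batch (the samples here are hypothetical):

    import torch

    samples = [
        {'source': torch.tensor([4, 5, 2])},
        {'source': torch.tensor([7, 2])},
    ]
    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
    print(src_lengths)  # tensor([3, 2])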
fairseq/models/fairseq_model.py

@@ -5,8 +5,9 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from . import FairseqDecoder, FairseqEncoder

@@ -34,11 +35,19 @@ class BaseFairseqModel(nn.Module):
     def get_normalized_probs(self, net_output, log_probs, sample=None):
         """Get normalized probabilities (or log probs) from a net's output."""
+        if hasattr(self, 'decoder'):
             return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+        elif torch.is_tensor(net_output):
+            logits = net_output.float()
+            if log_probs:
+                return F.log_softmax(logits, dim=-1)
+            else:
+                return F.softmax(logits, dim=-1)
+        raise NotImplementedError

     def max_positions(self):
         """Maximum length supported by the model."""
-        raise NotImplementedError
+        return None

     def max_decoder_positions(self):
         """Maximum length supported by the decoder."""

@@ -138,7 +147,7 @@ class FairseqLanguageModel(BaseFairseqModel):
         self.decoder = decoder
         assert isinstance(self.decoder, FairseqDecoder)

-    def forward(self, src_tokens):
+    def forward(self, src_tokens, src_lengths):
         return self.decoder(src_tokens)

     def max_positions(self):
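With the new fallback, a model that has no `decoder` attribute can still normalize its output, as long as `forward` returns a plain tensor of logits. A minimal standalone check of the tensor branch (outside fairseq, just PyTorch):

    import torch
    import torch.nn.functional as F

    net_output = torch.randn(2, 5, 100)  # (bsz, len, vocab) logits
    probs = F.softmax(net_output.float(), dim=-1)
    lprobs = F.log_softmax(net_output.float(), dim=-1)
    # Each position's probabilities sum to 1 after the softmax.
    print(torch.allclose(probs.sum(dim=-1), torch.ones(2, 5)))  # True

The `max_positions` change is in the same spirit: returning ``None`` now means "no limit", which pairs with the None-aware comparisons added in `fairseq/data/data_utils.py` and `fairseq/utils.py` in this commit.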
fairseq/tasks/fairseq_task.py

@@ -25,7 +25,7 @@ class FairseqTask(object):
     @classmethod
     def setup_task(cls, args, **kwargs):
-        raise NotImplementedError
+        return cls(args)

     def load_dataset(self, split, combine=False):
         raise NotImplementedError
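Defaulting `setup_task` to `return cls(args)` means a simple task, such as one written for the upcoming tutorial, only needs to override it when it must build dictionaries or preprocess `args` first. A hedged sketch of a minimal subclass (hypothetical task, not in the repo; assumes `FairseqTask` is importable from `fairseq.tasks` at this revision):

    from fairseq.tasks import FairseqTask

    class TutorialTask(FairseqTask):
        # No setup_task override needed: the base class now returns cls(args).

        def load_dataset(self, split, combine=False):
            # A task still has to say how each split is loaded.
            raise NotImplementedError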
fairseq/utils.py

@@ -403,6 +403,16 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
 def resolve_max_positions(*args):
     """Resolve max position constraints from multiple sources."""
+
+    def nullsafe_min(l):
+        minim = None
+        for item in l:
+            if minim is None:
+                minim = item
+            elif item is not None and item < minim:
+                minim = item
+        return minim
+
     max_positions = None
     for arg in args:
         if max_positions is None:

@@ -411,5 +421,7 @@ def resolve_max_positions(*args):
         if isinstance(arg, float) or isinstance(arg, int):
             max_positions = min(max_positions, arg)
         else:
-            max_positions = tuple(map(min, zip(max_positions, arg)))
+            max_positions = tuple(
+                map(nullsafe_min, zip(max_positions, arg))
+            )

     return max_positions
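`nullsafe_min` treats ``None`` as "unbounded": combining per-source limits no longer fails when one source leaves a dimension unconstrained (under Python 3, the plain `min` raises a `TypeError` on ``None``), and an open-ended constraint survives until some source pins it down. A worked example (standalone re-implementation for illustration):

    def nullsafe_min(l):
        minim = None
        for item in l:
            if minim is None:
                minim = item
            elif item is not None and item < minim:
                minim = item
        return minim

    task_limits = (1024, None)    # task caps the source side only
    model_limits = (None, 800)    # model caps the target side only
    print(tuple(map(nullsafe_min, zip(task_limits, model_limits))))
    # (1024, 800)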