OpenDAS / Fairseq · Commit cfd2a3a0

core changes to support latte collab

Authored Sep 20, 2018 by Alexei Baevski; committed by Myle Ott, Sep 25, 2018
Parent: fbe8ce65

Showing 13 changed files with 270 additions and 69 deletions (+270, -69)
fairseq/criterions/adaptive_loss.py     +11  -6
fairseq/data/__init__.py                 +1  -1
fairseq/data/dictionary.py              +17  -0
fairseq/data/language_pair_dataset.py    +1  -0
fairseq/data/monolingual_dataset.py     +82 -12
fairseq/data/token_block_dataset.py     +11  -3
fairseq/models/fairseq_model.py          +8  -1
fairseq/models/transformer.py           +19  -6
fairseq/modules/multihead_attention.py  +48 -28
fairseq/optim/fp16_optimizer.py          +3  -2
fairseq/sequence_generator.py            +4  -0
fairseq/tasks/language_modeling.py      +64  -9
tests/test_train.py                      +1  -1
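In brief, this commit generalizes the language modeling pipeline to predict multiple target types per position ('self', 'future', 'past'), adds an optional truncated output dictionary (--output-dictionary-size), replaces MultiheadAttention's mask_future_timesteps flag with an explicit attn_mask argument, and has the transformer decoder return attention plus its inner states in a dict. The per-file diffs follow.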
fairseq/criterions/adaptive_loss.py

@@ -42,11 +42,14 @@ class AdaptiveLoss(FairseqCriterion):
         adaptive_softmax = model.decoder.adaptive_softmax
         net_output = model(**sample['net_input'])
-        target = model.get_targets(sample, net_output).view(-1)
+        orig_target = model.get_targets(sample, net_output)
-        bsz = target.size(0)
+        nsentences = orig_target.size(0)
+        orig_target = orig_target.view(-1)
-        logits, target = adaptive_softmax(net_output[0], target)
+        bsz = orig_target.size(0)
+        logits, target = adaptive_softmax(net_output[0], orig_target)
         assert len(target) == len(logits)
         loss = net_output[0].new(1 if reduce else bsz).zero_()

@@ -57,11 +60,13 @@ class AdaptiveLoss(FairseqCriterion):
                 loss += F.cross_entropy(logits[i], target[i], size_average=False,
                                         ignore_index=self.padding_idx, reduce=reduce)
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        orig = utils.strip_pad(orig_target, self.padding_idx)
+        ntokens = orig.numel()
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
-            'ntokens': sample['ntokens'],
-            'nsentences': sample['target'].size(0),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
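The criterion now derives nsentences and ntokens from the un-flattened target itself, stripping padding, rather than trusting sample['ntokens']; presumably this keeps the counts correct once targets can be filtered against a truncated output dictionary. A minimal sketch of the new counting, on made-up tensors:

```python
import torch

# Hypothetical batch: two right-padded target sentences, padding_idx = 1.
padding_idx = 1
orig_target = torch.tensor([[4, 5, 6, 1, 1],
                            [7, 8, 1, 1, 1]])
nsentences = orig_target.size(0)                    # 2 sentences
# utils.strip_pad(x, pad).numel() boils down to counting non-pad symbols:
ntokens = orig_target.ne(padding_idx).sum().item()  # 5 real tokens
print(nsentences, ntokens)
```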
fairseq/data/__init__.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from .dictionary import Dictionary
+from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
 from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
fairseq/data/dictionary.py

@@ -199,3 +199,20 @@ class Dictionary(object):
         t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
         t[-1] = self.eos()
         return t
+
+
+class TruncatedDictionary(object):
+
+    def __init__(self, wrapped_dict, length):
+        self.__class__ = type(dict.__class__.__name__, (self.__class__, dict.__class__), {})
+        self.__dict__ = dict.__dict__
+        self.wrapped_dict = wrapped_dict
+        self.length = min(len(self.wrapped_dict), length)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        if i < self.length:
+            return self.wrapped_dict[i]
+        return self.wrapped_dict.unk()
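TruncatedDictionary exposes only the first `length` entries of a wrapped dictionary; any lookup past the cutoff falls back to the unknown symbol. A standalone sketch of that behaviour (a hypothetical re-implementation, not the fairseq class, which additionally rewires `__class__` and `__dict__` onto the wrapped object):

```python
class TinyTruncatedDict:
    """Sketch: len() is clipped and out-of-range lookups return unk."""

    def __init__(self, symbols, length, unk_index=2):
        self.symbols = symbols
        self.unk_index = unk_index
        self.length = min(len(symbols), length)

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        if i < self.length:
            return self.symbols[i]
        return self.symbols[self.unk_index]

d = TinyTruncatedDict(['<pad>', '</s>', '<unk>', 'the', 'cat', 'sat'], length=5)
assert len(d) == 5
assert d[4] == 'cat'
assert d[5] == '<unk>'  # 'sat' lies beyond the truncated size
```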
fairseq/data/language_pair_dataset.py

@@ -61,6 +61,7 @@ def collate(
             'src_lengths': src_lengths,
         },
         'target': target,
+        'nsentences': samples[0]['source'].size(0),
     }
     if prev_output_tokens is not None:
         batch['net_input']['prev_output_tokens'] = prev_output_tokens
fairseq/data/monolingual_dataset.py

@@ -9,27 +9,39 @@ import numpy as np
 import torch

 from . import data_utils, FairseqDataset
+from typing import List


 def collate(samples, pad_idx, eos_idx):
     if len(samples) == 0:
         return {}

-    def merge(key):
+    def merge(key, is_list=False):
+        if is_list:
+            res = []
+            for i in range(len(samples[0][key])):
+                res.append(data_utils.collate_tokens(
+                    [s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False,
+                ))
+            return res
+        else:
             return data_utils.collate_tokens(
                 [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
             )

+    is_target_list = isinstance(samples[0]['target'], list)
+
     return {
         'id': torch.LongTensor([s['id'] for s in samples]),
-        'ntokens': sum(len(s['target']) for s in samples),
+        'ntokens': sum(len(s['source']) for s in samples),
         'net_input': {
             'src_tokens': merge('source'),
             'src_lengths': torch.LongTensor([s['source'].numel() for s in samples]),
         },
-        'target': merge('target'),
+        'target': merge('target', is_target_list),
+        'nsentences': samples[0]['source'].size(0),
     }

@@ -45,19 +57,75 @@ class MonolingualDataset(FairseqDataset):
         Default: ``True``
     """

-    def __init__(self, dataset, sizes, vocab, shuffle=True):
+    def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle, targets=None):
         self.dataset = dataset
         self.sizes = np.array(sizes)
-        self.vocab = vocab
+        self.vocab = src_vocab
+        self.tgt_vocab = tgt_vocab
+        self.add_eos_for_other_targets = add_eos_for_other_targets
         self.shuffle = shuffle

+        assert targets is None or all(t in {'self', 'future', 'past'} for t in targets), \
+            "targets must be none or one of 'self', 'future', 'past'"
+        if targets is not None and len(targets) == 0:
+            targets = None
+        self.targets = targets
+
     def __getitem__(self, index):
-        source, target = self.dataset[index]
+        source, future_target, past_target = self.dataset[index]
+        source, target = self._make_source_target(source, future_target, past_target)
         return {'id': index, 'source': source, 'target': target}

     def __len__(self):
         return len(self.dataset)

+    def _make_source_target(self, source, future_target, past_target):
+        if self.targets is not None:
+            target = []
+
+            if self.add_eos_for_other_targets and (('self' in self.targets) or ('past' in self.targets)) \
+                    and source[-1] != self.vocab.eos():
+                # append eos at the end of source
+                source = torch.cat([source, source.new([self.vocab.eos()])])
+
+                if 'future' in self.targets:
+                    future_target = torch.cat([future_target, future_target.new([self.vocab.pad()])])
+                if 'past' in self.targets:
+                    # first token is before the start of sentence which is only used in "none" break mode when
+                    # add_eos_for_other_targets is False
+                    past_target = torch.cat([past_target.new([self.vocab.pad()]), past_target[1:], source[-2, None]])
+
+            for t in self.targets:
+                if t == 'self':
+                    target.append(source)
+                elif t == 'future':
+                    target.append(future_target)
+                elif t == 'past':
+                    target.append(past_target)
+                else:
+                    raise Exception('invalid target ' + t)
+
+            if len(target) == 1:
+                target = target[0]
+        else:
+            target = future_target
+
+        return source, self._filter_vocab(target)
+
+    def _filter_vocab(self, target):
+        if len(self.tgt_vocab) != len(self.vocab):
+            def _filter(target):
+                mask = target.ge(len(self.tgt_vocab))
+                if mask.any():
+                    target[mask] = self.tgt_vocab.unk()
+                return target
+
+            if isinstance(target, list):
+                return [_filter(t) for t in target]
+            return _filter(target)
+        return target
+
     def collater(self, samples):
         """Merge a list of samples to form a mini-batch.

@@ -86,8 +154,10 @@ class MonolingualDataset(FairseqDataset):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             tgt_len = min(tgt_len, max_positions)
         bsz = num_tokens // tgt_len
-        target = self.vocab.dummy_sentence(tgt_len + 1)
-        source, target = target[:-1], target[1:]
+        target = self.vocab.dummy_sentence(tgt_len + 2)
+        source, past_target, future_target = target[1:-1], target[2:], target[:-2]
+        source, target = self._make_source_target(source, past_target, future_target)
         return self.collater([
             {'id': i, 'source': source, 'target': target}
             for i in range(bsz)
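When the output dictionary is a truncated copy of the input dictionary, _filter_vocab remaps any target id at or beyond the truncated size to unk, so targets stay valid for the smaller output softmax. The core of that operation, on hypothetical values:

```python
import torch

tgt_vocab_size, unk_index = 5, 2          # assumed truncated size / unk id
target = torch.tensor([3, 4, 7, 1])       # 7 is out of range after truncation
mask = target.ge(tgt_vocab_size)          # which ids exceed the cutoff
if mask.any():
    target[mask] = unk_index
print(target)                             # tensor([3, 4, 2, 1])
```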
fairseq/data/token_block_dataset.py

@@ -29,11 +29,13 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         include_targets: return next tokens as targets
     """

-    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
+    def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
         super().__init__()
         self.tokens = tokens
         self.total_size = len(tokens)
+        self.pad = pad
+        self.eos = eos
         self.include_targets = include_targets
         self.slice_indices = []

@@ -81,12 +83,18 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         if self.include_targets:
             # target is the sentence, for source, rotate item one token to the left (would start with eos)
+            # past target is rotated to the left by 2 (padded if its first)
             if s == 0:
-                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
+                source = np.concatenate([[self.eos], self.tokens[0:e - 1]])
+                past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]])
             else:
                 source = self.tokens[s - 1:e - 1]
+                if s == 1:
+                    past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]])
+                else:
+                    past_target = self.tokens[s - 2:e - 2]

-            return torch.LongTensor(source), item
+            return torch.LongTensor(source), item, torch.LongTensor(past_target)
         return item

     def __len__(self):
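With include_targets set, each block now yields three aligned sequences: the block itself (the 'future' target), a source rotated one step right (starting with eos), and a past target rotated two steps (front-padded). A sketch of the s == 0 branch on a toy token stream, with assumed special-symbol ids:

```python
import numpy as np

eos, pad = 1, 0                 # assumed eos / pad ids
stream = np.array([10, 11, 12, 13, 14])
e = 4                           # first block covers stream[0:4]

item = stream[0:e]                                           # [10 11 12 13]
source = np.concatenate([[eos], stream[0:e - 1]])            # [ 1 10 11 12]
past_target = np.concatenate([[pad, eos], stream[0:e - 2]])  # [ 0  1 10 11]
```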
fairseq/models/fairseq_model.py

@@ -65,6 +65,9 @@ class BaseFairseqModel(nn.Module):
     def upgrade_state_dict(self, state_dict):
         """Upgrade old state dicts to work with newer code."""
+        self.upgrade_state_dict_named(state_dict, '')
+
+    def upgrade_state_dict_named(self, state_dict, name):
         assert state_dict is not None

         def do_upgrade(m, prefix):

@@ -79,7 +82,7 @@ class BaseFairseqModel(nn.Module):
                 c.upgrade_state_dict(state_dict)
                 do_upgrade(c, name)

-        do_upgrade(self, '')
+        do_upgrade(self, name)

     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""

@@ -196,3 +199,7 @@ class FairseqLanguageModel(BaseFairseqModel):
     def max_positions(self):
         """Maximum length supported by the model."""
         return self.decoder.max_positions()
+
+    @property
+    def supported_targets(self):
+        return {'future'}
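The new supported_targets property lets tasks ask which target types a model can actually produce; FairseqLanguageModel advertises only standard next-token ('future') prediction, and LanguageModelingTask.build_model (later in this diff) validates its configured targets against it. A toy sketch of that handshake (SketchLM is hypothetical):

```python
class SketchLM:
    @property
    def supported_targets(self):
        return {'future'}

model = SketchLM()
for target in ['self', 'future']:
    ok = target in model.supported_targets
    print(target, 'supported' if ok else 'unsupported')  # self is unsupported
```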
fairseq/models/transformer.py

@@ -213,7 +213,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
         else:
             embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())

-        decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
+        decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
         return TransformerLanguageModel(decoder)

@@ -442,6 +442,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             x = x.transpose(0, 1)
         attn = None

+        inner_states = [x]
+
         # decoder layers
         for layer in self.layers:
             x, attn = layer(

@@ -449,7 +451,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 encoder_out['encoder_out'] if encoder_out is not None else None,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
+                self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
             )
+            inner_states.append(x)

         if self.normalize:
             x = self.layer_norm(x)

@@ -467,7 +471,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         else:
             x = F.linear(x, self.embed_out)

-        return x, attn
+        return x, {'attn': attn, 'inner_states': inner_states}

     def max_positions(self):
         """Maximum output length supported by the decoder."""

@@ -475,6 +479,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             return self.max_target_positions
         return min(self.max_target_positions, self.embed_positions.max_positions())

+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
     def upgrade_state_dict(self, state_dict):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):

@@ -615,7 +627,8 @@ class TransformerDecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)
         self.need_attn = True

-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
+                self_attn_mask=None, self_attn_padding_mask=None):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`

@@ -631,9 +644,10 @@ class TransformerDecoderLayer(nn.Module):
             query=x,
             key=x,
             value=x,
-            mask_future_timesteps=True,
+            key_padding_mask=self_attn_padding_mask,
             incremental_state=incremental_state,
             need_weights=False,
+            attn_mask=self_attn_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x

@@ -728,7 +742,6 @@ def base_lm_architecture(args):
     # The model training is not stable without this
     args.decoder_normalize_before = True

-
 @register_model_architecture('transformer_lm', 'transformer_lm_big')
 def transformer_lm_big(args):
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)

@@ -740,7 +753,7 @@ def transformer_lm_big(args):
 @register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
 def transformer_lm_wiki103(args):
     args.dropout = getattr(args, 'dropout', 0.3)
-    base_lm_architecture(args)
+    transformer_lm_big(args)


 @register_model_architecture('transformer_lm', 'transformer_lm_gbw')
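buffered_future_mask caches an upper-triangular matrix of -inf that, added to attention scores, stops position i from attending to positions j > i; it is passed to each decoder layer as self_attn_mask only when incremental_state is None (training/scoring, not incremental decoding). The mask in isolation, as a minimal sketch (torch.full stands in for fairseq's utils.fill_with_neg_inf):

```python
import torch

def future_mask(dim):
    # -inf strictly above the diagonal, zeros on and below it
    return torch.triu(torch.full((dim, dim), float('-inf')), 1)

print(future_mask(3))
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```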
fairseq/modules/multihead_attention.py

@@ -18,23 +18,31 @@ class MultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """

-    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
+    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim**-0.5
-        self._mask = None
+        self.scaling = self.head_dim ** -0.5

-        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
         self.reset_parameters()

     def reset_parameters(self):

@@ -43,15 +51,18 @@ class MultiheadAttention(nn.Module):
         if self.in_proj_bias is not None:
             nn.init.constant_(self.in_proj_bias, 0.)
             nn.init.constant_(self.out_proj.bias, 0.)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)

-    def forward(self, query, key, value, mask_future_timesteps=False,
-                key_padding_mask=None, incremental_state=None,
-                need_weights=True, static_kv=False):
+    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
+                need_weights=True, static_kv=False, attn_mask=None):
         """Input shape: Time x Batch x Channel

-        Self-attention can be implemented by passing in the same arguments for
-        query, key and value. Future timesteps can be masked with the
-        `mask_future_timesteps` argument. Padding elements can be excluded from
+        Self-attention can be implemented by passing in the same arguments for
+        query, key and value. Timesteps can be masked by supplying a T x T mask in the
+        `attn_mask` argument. Padding elements can be excluded from
         the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
         batch x src_len, where padding elements are indicated by 1s.
         """

@@ -103,24 +114,40 @@ class MultiheadAttention(nn.Module):
                 saved_state['prev_value'] = v
             self._set_input_buffer(incremental_state, saved_state)

+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
         src_len = k.size(0)

         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz
             assert key_padding_mask.size(1) == src_len

-        q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)

         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

-        # only apply masking at training time (when incremental state is None)
-        if mask_future_timesteps and incremental_state is None:
-            assert query.size() == key.size(), \
-                'mask_future_timesteps only applies to self-attention'
-            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
+        if attn_mask is not None:
+            attn_weights += attn_mask.unsqueeze(0)

         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)

@@ -129,6 +156,7 @@ class MultiheadAttention(nn.Module):
                 float('-inf'),
             ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
         attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

@@ -156,10 +184,10 @@ class MultiheadAttention(nn.Module):
         return self._in_proj(query, end=self.embed_dim)

     def in_proj_k(self, key):
-        return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
+        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

     def in_proj_v(self, value):
-        return self._in_proj(value, start=2*self.embed_dim)
+        return self._in_proj(value, start=2 * self.embed_dim)

     def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight

@@ -169,14 +197,6 @@ class MultiheadAttention(nn.Module):
             bias = bias[start:end]
         return F.linear(input, weight, bias)

-    def buffered_mask(self, tensor):
-        dim = tensor.size(-1)
-        if self._mask is None:
-            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
-        if self._mask.size(0) < dim:
-            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
-        return self._mask[:dim, :dim]
-
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
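The hard-wired mask_future_timesteps flag (and its buffered_mask helper) is gone; callers now pass any T x T additive mask through attn_mask, which is simply added to the raw scores before the softmax. A minimal sketch of that path, on toy tensors:

```python
import torch
import torch.nn.functional as F

tgt_len = src_len = 3
attn_weights = torch.zeros(1, tgt_len, src_len)   # (bsz * heads, tgt, src)
attn_mask = torch.triu(torch.full((tgt_len, src_len), float('-inf')), 1)

attn_weights += attn_mask.unsqueeze(0)            # additive mask
attn_probs = F.softmax(attn_weights.float(), dim=-1)
print(attn_probs[0])   # row i assigns zero probability to positions j > i
```

Decoupling the mask from the module also makes the add_bias_kv and add_zero_attn paths above straightforward: each just appends one column of zeros to attn_mask and key_padding_mask to cover the extra key.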
fairseq/optim/fp16_optimizer.py

@@ -106,8 +106,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
         for p in self.params:
             if not p.requires_grad:
                 continue
-            numel = p.grad.data.numel()
-            self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
+            grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
+            numel = grad_data.numel()
+            self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
             offset += numel

         # correct for dynamic loss scaler
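Previously a parameter that never received a gradient (p.grad is None, e.g. a head unused in the current forward) would crash the fp32-copy step; it now contributes zeros of the right shape. The guard in isolation:

```python
import torch

p = torch.nn.Parameter(torch.ones(4))    # no backward yet, so p.grad is None
grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
print(grad_data)                         # tensor([0., 0., 0., 0.])
```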
fairseq/sequence_generator.py

@@ -507,7 +507,11 @@ class SequenceGenerator(object):
         decoder_out = list(model.decoder(tokens, encoder_out))
         decoder_out[0] = decoder_out[0][:, -1, :]
         attn = decoder_out[1]
+        if type(attn) is dict:
+            attn = attn['attn']
         if attn is not None:
+            if type(attn) is dict:
+                attn = attn['attn']
             attn = attn[:, -1, :]
         probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
         return probs, attn
fairseq/tasks/language_modeling.py

@@ -13,7 +13,7 @@ from torch.utils.data import ConcatDataset
 from fairseq.data import (
     Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
-    MonolingualDataset, TokenBlockDataset,
+    MonolingualDataset, TokenBlockDataset, TruncatedDictionary
 )

 from . import FairseqTask, register_task

@@ -25,7 +25,14 @@ class LanguageModelingTask(FairseqTask):
     Train a language model.

     Args:
-        dictionary (Dictionary): the dictionary for the language model
+        dictionary (Dictionary): the dictionary for the input of the language model
+        output_dictionary (Dictionary): the dictionary for the output of the language model.
+            In most cases it will be the same as dictionary, but could possibly be a more limited
+            version of the dictionary (if --output-dictionary-size is used).
+        targets (List[str]): list of the target types that the language model should predict.
+            Can be one of "self", "future", and "past". Defaults to "future".

     .. note::

@@ -55,10 +62,23 @@ class LanguageModelingTask(FairseqTask):
                             help='max number of tokens per sample for LM dataset')
         parser.add_argument('--raw-text', default=False, action='store_true',
                             help='load raw text dataset')
+        parser.add_argument('--output-dictionary-size', default=-1, type=int,
+                            help='limit the size of output dictionary')
+        parser.add_argument('--self-target', action='store_true',
+                            help='include self target')
+        parser.add_argument('--future-target', action='store_true',
+                            help='include future target')
+        parser.add_argument('--past-target', action='store_true',
+                            help='include past target')

-    def __init__(self, args, dictionary):
+    def __init__(self, args, dictionary, output_dictionary, targets=None):
         super().__init__(args)
         self.dictionary = dictionary
+        self.output_dictionary = output_dictionary
+
+        if targets is None:
+            targets = ['future']
+        self.targets = targets

     @classmethod
     def setup_task(cls, args, **kwargs):

@@ -69,7 +89,36 @@ class LanguageModelingTask(FairseqTask):
         """
         dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
         print('| dictionary: {} types'.format(len(dictionary)))
-        return cls(args, dictionary)
+
+        output_dictionary = dictionary
+        if args.output_dictionary_size >= 0:
+            output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
+
+        # upgrade old checkpoints
+        if hasattr(args, 'exclude_self_target'):
+            args.self_target = not args.exclude_self_target
+
+        targets = []
+        if args.self_target:
+            targets.append('self')
+        if args.future_target:
+            targets.append('future')
+        if args.past_target:
+            targets.append('past')
+        if len(targets) == 0:
+            # standard language modeling
+            targets = ['future']
+
+        return cls(args, dictionary, output_dictionary, targets=targets)
+
+    def build_model(self, args):
+        model = super().build_model(args)
+
+        for target in self.targets:
+            if target not in model.supported_targets:
+                raise ValueError('Unsupported language modeling target: {}'.format(target))
+
+        return model

     def load_dataset(self, split, combine=False):
         """Load a given dataset split.

@@ -98,8 +147,8 @@ class LanguageModelingTask(FairseqTask):
             loaded_datasets.append(
                 TokenBlockDataset(
-                    tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
-                    include_targets=True
+                    tokens, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(),
+                    eos=self.dictionary.eos(), break_mode=self.args.sample_break_mode, include_targets=True,
                 ))

             print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))

@@ -114,10 +163,16 @@ class LanguageModelingTask(FairseqTask):
             dataset = ConcatDataset(loaded_datasets)
             sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

-        self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
+        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
+
+        self.datasets[split] = MonolingualDataset(
+            dataset, sizes, self.dictionary, self.output_dictionary,
+            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=False, targets=self.targets,
+        )

     @property
     def target_dictionary(self):
         """Return the :class:`~fairseq.data.Dictionary` for the language
         model."""
-        return self.dictionary
+        return self.output_dictionary
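The new --self-target, --future-target and --past-target flags compose freely; when none is given the task falls back to plain next-token prediction, and old checkpoints carrying exclude_self_target are upgraded in place. The selection logic from setup_task, restated on a hypothetical args object:

```python
from argparse import Namespace

args = Namespace(self_target=False, future_target=False, past_target=True)

targets = []
if args.self_target:
    targets.append('self')
if args.future_target:
    targets.append('future')
if args.past_target:
    targets.append('past')
if len(targets) == 0:
    targets = ['future']   # standard language modeling
print(targets)             # ['past']
```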
tests/test_train.py

@@ -40,7 +40,7 @@ def mock_dict():
 def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
     tokens = torch.LongTensor(list(range(epoch_size)))
-    tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
+    tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
     dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
     epoch_itr = data.EpochBatchIterator(