OpenDAS / Fairseq / Commits

Commit 753935ef
authored Aug 27, 2018 by Myle Ott

Merge internal changes

parent c7c567a7
Showing 6 changed files with 47 additions and 21 deletions:

    fairseq/criterions/label_smoothed_cross_entropy.py   +12  -8
    fairseq/models/fairseq_model.py                        +8  -0
    fairseq/modules/conv_tbc.py                            +1  -1
    fairseq/modules/multihead_attention.py                 +4  -9
    fairseq/modules/sinusoidal_positional_embedding.py    +12  -1
    fairseq/utils.py                                      +10  -2
fairseq/criterions/label_smoothed_cross_entropy.py

@@ -34,6 +34,17 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
+        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        logging_output = {
+            'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
+            'ntokens': sample['ntokens'],
+            'sample_size': sample_size,
+        }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         lprobs = lprobs.view(-1, lprobs.size(-1))
         target = model.get_targets(sample, net_output).view(-1, 1)
@@ -45,15 +56,8 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         smooth_loss = smooth_loss.sum()
         eps_i = self.eps / lprobs.size(-1)
         loss = (1. - self.eps) * nll_loss + eps_i * smooth_loss
+        return loss, nll_loss
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
-        logging_output = {
-            'loss': utils.item(loss.data) if reduce else loss.data,
-            'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
-            'ntokens': sample['ntokens'],
-            'sample_size': sample_size,
-        }
-        return loss, sample_size, logging_output

     @staticmethod
     def aggregate_logging_outputs(logging_outputs):
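This hunk factors the loss arithmetic out of forward() into a reusable compute_loss(). The quantity it computes is loss = (1 - eps) * nll_loss + (eps / V) * smooth_loss, where nll_loss gathers the log-probability of the gold token and smooth_loss sums log-probabilities over the whole vocabulary V. A minimal self-contained sketch of that formula in plain PyTorch (not the fairseq classes; it also omits the padding mask the criterion applies):

    import torch

    def label_smoothed_nll_loss(lprobs, target, eps):
        # lprobs: (batch, vocab) log-probabilities; target: (batch, 1) gold indices
        nll_loss = -lprobs.gather(dim=-1, index=target).sum()  # standard NLL term
        smooth_loss = -lprobs.sum()                            # sum over the full vocabulary
        eps_i = eps / lprobs.size(-1)                          # spread eps uniformly over V
        return (1. - eps) * nll_loss + eps_i * smooth_loss, nll_loss

    lprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)
    target = torch.tensor([[0], [2], [4]])
    loss, nll = label_smoothed_nll_loss(lprobs, target, eps=0.1)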
fairseq/models/fairseq_model.py

@@ -100,6 +100,14 @@ class BaseFairseqModel(nn.Module):
         self.eval()
         self.train = train

+    def prepare_for_onnx_export_(self, **kwargs):
+        """Make model exportable via ONNX trace."""
+        def apply_prepare_for_onnx_export_(module):
+            if module != self and hasattr(module, 'prepare_for_onnx_export_'):
+                module.prepare_for_onnx_export_(**kwargs)
+        self.apply(apply_prepare_for_onnx_export_)
+

 class FairseqModel(BaseFairseqModel):
     """Base class for encoder-decoder models."""
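The hook relies on nn.Module.apply, which visits the module and every submodule recursively; the `module != self` guard keeps the parent from re-invoking its own hook. A standalone sketch of the same pattern (the class names here are illustrative, not from the commit):

    import torch.nn as nn

    class ExportAware(nn.Module):
        def __init__(self):
            super().__init__()
            self.onnx_trace = False

        def prepare_for_onnx_export_(self):
            # flip this module into its trace-friendly code path
            self.onnx_trace = True

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = ExportAware()

        def prepare_for_onnx_export_(self, **kwargs):
            def visit(module):
                if module != self and hasattr(module, 'prepare_for_onnx_export_'):
                    module.prepare_for_onnx_export_(**kwargs)
            self.apply(visit)  # nn.Module.apply walks self and all submodules

    m = Wrapper()
    m.prepare_for_onnx_export_()
    assert m.inner.onnx_trace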
fairseq/modules/conv_tbc.py

@@ -27,7 +27,7 @@ class ConvTBC(torch.nn.Module):
         self.bias = torch.nn.Parameter(torch.Tensor(out_channels))

     def forward(self, input):
-        return input.contiguous().conv_tbc(self.weight, self.bias, self.padding[0])
+        return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding[0])

     def __repr__(self):
         s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
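A one-line change: the tensor method call becomes the equivalent torch.conv_tbc function. ConvTBC takes time-first input (Time x Batch x Channel) rather than nn.Conv1d's (Batch x Channel x Time), and its weight is laid out as (kernel_width, in_channels, out_channels). A shape-only sketch, assuming torch.conv_tbc is available in your PyTorch build:

    import torch

    T, B, Cin, Cout, K = 7, 2, 4, 8, 3
    x = torch.randn(T, B, Cin)              # time-first layout
    weight = torch.randn(K, Cin, Cout)      # (kernel_width, in_channels, out_channels)
    bias = torch.zeros(Cout)
    y = torch.conv_tbc(x, weight, bias, 1)  # pad=1 keeps the time dimension: 7+2-3+1 = 7
    print(y.shape)                          # torch.Size([7, 2, 8])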
fairseq/modules/multihead_attention.py

@@ -161,17 +161,12 @@ class MultiheadAttention(nn.Module):
     def in_proj_v(self, value):
         return self._in_proj(value, start=2 * self.embed_dim)

-    def _in_proj(self, input, start=None, end=None):
+    def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
         bias = self.in_proj_bias
-        if end is not None:
-            weight = weight[:end, :]
-            if bias is not None:
-                bias = bias[:end]
-        if start is not None:
-            weight = weight[start:, :]
-            if bias is not None:
-                bias = bias[start:]
+        weight = weight[start:end, :]
+        if bias is not None:
+            bias = bias[start:end]
         return F.linear(input, weight, bias)

     def buffered_mask(self, tensor):
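The simplification works because in_proj_weight stacks the query, key, and value projections along dim 0, so q/k/v are contiguous row slices [0:E], [E:2E], [2E:3E] of one (3E x E) matrix, and Python slicing with start=0, end=None covers all three cases. A standalone sketch of that slicing (E stands in for embed_dim; the names mirror the module):

    import torch
    import torch.nn.functional as F

    E = 4
    in_proj_weight = torch.randn(3 * E, E)  # rows 0:E -> q, E:2E -> k, 2E:3E -> v
    in_proj_bias = torch.randn(3 * E)

    def _in_proj(x, start=0, end=None):
        w = in_proj_weight[start:end, :]
        b = in_proj_bias[start:end]
        return F.linear(x, w, b)

    x = torch.randn(2, E)
    q = _in_proj(x, end=E)                  # in_proj_q
    k = _in_proj(x, start=E, end=2 * E)     # in_proj_k
    v = _in_proj(x, start=2 * E)            # in_proj_v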
fairseq/modules/sinusoidal_positional_embedding.py

@@ -30,8 +30,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
+        self.onnx_trace = False
         self.register_buffer('_float_tensor', torch.FloatTensor(1))

+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
         """Build sinusoidal embeddings.
@@ -68,7 +72,14 @@ class SinusoidalPositionalEmbedding(nn.Module):
             # positions is the same for every token when decoding a single step
             return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

-        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
+        if self.onnx_trace:
+            bsz = torch.onnx.operators.shape_as_tensor(input)[0]
+            seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
+            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
+            embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
+            embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
+            return embeddings
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

     def max_positions(self):
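Why the detour: view(bsz, seq_len, -1) bakes the concrete Python ints bsz and seq_len into an ONNX trace as constants, while shape_as_tensor / reshape_from_tensor_shape keep the target shape as a tensor in the graph, so the exported model accepts any batch and sequence length. A minimal sketch of the contrast, assuming a PyTorch build that ships torch.onnx.operators:

    import torch
    import torch.onnx.operators

    x = torch.randn(2, 5, 8)
    flat = x.view(-1, 8)

    # Trace-unfriendly: 2 and 5 would be captured as constants.
    y1 = flat.view(2, 5, 8)

    # Trace-friendly: the target shape stays a tensor in the graph.
    shape = torch.onnx.operators.shape_as_tensor(x)  # tensor([2, 5, 8])
    y2 = torch.onnx.operators.reshape_from_tensor_shape(flat, shape)
    assert torch.equal(y1, y2)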
fairseq/utils.py

@@ -46,7 +46,7 @@ def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
         extra_state = {}
     state_dict = {
         'args': args,
-        'model': convert_state_dict_type(model.state_dict()),
+        'model': model.state_dict() if model else {},
         'optimizer_history': optim_history + [
             {
                 'criterion_name': criterion.__class__.__name__,
@@ -298,7 +298,7 @@ def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dic
     return hypo_tokens, hypo_str, alignment


-def make_positions(tensor, padding_idx, left_pad):
+def make_positions(tensor, padding_idx, left_pad, onnx_trace=False):
     """Replace non-padding symbols with their position numbers.

     Position numbers begin at padding_idx+1.
@@ -306,6 +306,14 @@ def make_positions(tensor, padding_idx, left_pad):
     Padding symbols are ignored, but it is necessary to specify whether padding
     is added on the left side (left_pad=True) or right side (left_pad=False).
     """
+    if onnx_trace:
+        range_buf = torch._dim_arange(like=tensor, dim=1) + padding_idx + 1
+        mask = tensor.ne(padding_idx)
+        positions = range_buf.expand_as(tensor)
+        if left_pad:
+            positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
+        return positions * mask.long() + padding_idx * (1 - mask.long())
+
     max_pos = padding_idx + 1 + tensor.size(1)
     if not hasattr(make_positions, 'range_buf'):
         make_positions.range_buf = tensor.new()
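Concretely: for a right-padded batch with padding_idx=1, non-pad tokens get positions padding_idx+1, padding_idx+2, ... while pad slots stay at padding_idx, matching the docstring. A numeric sketch of the traced branch above, with torch.arange standing in for the private torch._dim_arange:

    import torch

    padding_idx = 1
    tensor = torch.tensor([[5, 6, 7, 1],   # 1 is padding (right-padded)
                           [8, 9, 1, 1]])

    range_buf = torch.arange(tensor.size(1)) + padding_idx + 1  # tensor([2, 3, 4, 5])
    mask = tensor.ne(padding_idx)
    positions = range_buf.expand_as(tensor)
    out = positions * mask.long() + padding_idx * (1 - mask.long())
    # out: [[2, 3, 4, 1],
    #       [2, 3, 1, 1]]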