OpenDAS / Fairseq · Commits

Commit 50fdf591, authored Nov 11, 2017 by Myle Ott (parent 5c7f4954)

Don't call forward directly (prefer module(x) to module.forward(x))
Showing 4 changed files with 13 additions and 19 deletions (+13 -19):
fairseq/models/fairseq_incremental_decoder.py  (+5 -8)
fairseq/models/fconv.py                        (+7 -0)
fairseq/modules/linearized_convolution.py      (+0 -7)
fairseq/sequence_generator.py                  (+1 -4)
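The rationale is a general PyTorch contract: `module(x)` goes through `nn.Module.__call__`, which runs registered hooks (notably the forward pre-hook that `nn.utils.weight_norm` uses to recompute weights) before dispatching to `forward`, while `module.forward(x)` bypasses them entirely. A minimal sketch of the difference, using current PyTorch APIs and a hypothetical `announce` hook added only for illustration:

```python
import torch
import torch.nn as nn

def announce(module, input):
    # Forward pre-hooks receive the module and its inputs before forward
    # runs; weight norm uses exactly this mechanism to recompute weights.
    print('pre-hook ran on', type(module).__name__)

layer = nn.Linear(4, 4)
layer.register_forward_pre_hook(announce)

x = torch.randn(2, 4)
layer(x)          # goes through __call__: the pre-hook fires, then forward
layer.forward(x)  # skips __call__: the pre-hook is silently bypassed
```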
fairseq/models/fairseq_incremental_decoder.py
@@ -18,12 +18,10 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         self._incremental_state = {}
 
     def forward(self, tokens, encoder_out):
-        raise NotImplementedError
-
-    def incremental_forward(self, tokens, encoder_out):
-        """Forward pass for one time step."""
-        # keep only the last token for incremental forward pass
-        return self.forward(tokens[:, -1:], encoder_out)
+        if self._is_incremental_eval:
+            raise NotImplementedError
+        else:
+            raise NotImplementedError
 
     def incremental_inference(self):
         """Context manager for incremental inference.

@@ -38,8 +36,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
         ```
         with model.decoder.incremental_inference():
             for step in range(maxlen):
-                out, _ = model.decoder.incremental_forward(
-                    tokens[:, :step], encoder_out)
+                out, _ = model.decoder(tokens[:, :step], encoder_out)
                 probs = torch.nn.functional.log_softmax(out[:, -1, :])
         ```
         """
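The `_is_incremental_eval` flag that the new `forward` branches on is toggled by the `incremental_inference` context manager referenced in the docstring; its body lies outside this diff. A plausible sketch of the behavior it needs to provide, assuming it merely flips the flag and clears cached state (not the actual fairseq implementation):

```python
from contextlib import contextmanager

@contextmanager
def incremental_inference(self):
    # Hypothetical sketch: enter incremental evaluation mode so that
    # forward() takes the single-step path, then restore batch mode and
    # drop any cached per-step state on exit.
    self._is_incremental_eval = True
    try:
        yield
    finally:
        self._is_incremental_eval = False
        self._incremental_state = {}  # assumed cleanup of cached state
```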
fairseq/models/fconv.py
@@ -185,6 +185,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
         self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
     def forward(self, input_tokens, encoder_out):
+        if self._is_incremental_eval:
+            return self.incremental_forward(input_tokens, encoder_out)
+        else:
+            return self.batch_forward(input_tokens, encoder_out)
+
+    def batch_forward(self, input_tokens, encoder_out):
+        """Forward pass for decoding multiple time steps in batch mode."""
         positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
                                             left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
         return self._forward(input_tokens, positions, encoder_out)
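With the dispatch above, callers never choose between `batch_forward` and `incremental_forward` themselves; the same module call serves both modes:

```python
# Batch mode (default): decoder(...) routes to batch_forward.
out, attn = model.decoder(tokens, encoder_out)

# Incremental mode: inside the context manager, the identical call
# routes to incremental_forward instead.
with model.decoder.incremental_inference():
    out, attn = model.decoder(tokens, encoder_out)
```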
fairseq/modules/linearized_convolution.py
@@ -50,13 +50,6 @@ class LinearizedConvolution(ConvTBC):
         call reorder_incremental_state. To apply to fresh inputs, call
         clear_incremental_state.
         """
-        if self.training or not self._is_incremental_eval:
-            raise RuntimeError('incremental_forward only supports incremental evaluation')
-
-        # run forward pre hooks (e.g., weight norm)
-        for hook in self._forward_pre_hooks.values():
-            hook(self, input)
-
         # reshape weight
         weight = self._get_linearized_weight()
         kw = self.kernel_size[0]
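The deleted block is the flip side of the same cleanup: it hand-rolled the pre-hook loop that `nn.Module.__call__` already performs, which was only needed while callers invoked the method directly. An abridged sketch of what `__call__` did in PyTorch at the time (simplified; the real code also validates hook return values):

```python
# Abridged sketch of torch.nn.Module.__call__ (circa PyTorch 0.2/0.3),
# showing why the removed manual loop was redundant once modules are
# invoked as module(x):
def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        hook(self, input)                 # e.g., weight norm recomputes weights
    result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook(self, input, result)
    return result
```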
fairseq/sequence_generator.py
@@ -325,10 +325,7 @@ class SequenceGenerator(object):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            if isinstance(model.decoder, FairseqIncrementalDecoder):
-                decoder_out, attn = model.decoder.incremental_forward(tokens, encoder_out)
-            else:
-                decoder_out, attn = model.decoder.forward(tokens, encoder_out)
+            decoder_out, attn = model.decoder(tokens, encoder_out)
             probs = F.softmax(decoder_out[:, -1, :]).data
             attn = attn[:, -1, :].data
             if avg_probs is None or avg_attn is None: