Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
a09fe803
Commit
a09fe803
authored
Dec 22, 2017
by
Myle Ott
Browse files
Fix BeamableMM
parent
9f7c3ec6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
10 deletions
+13
-10
fairseq/models/fairseq_incremental_decoder.py
fairseq/models/fairseq_incremental_decoder.py
+6
-4
fairseq/models/fairseq_model.py
fairseq/models/fairseq_model.py
+5
-5
fairseq/models/fconv.py
fairseq/models/fconv.py
+2
-1
No files found.
fairseq/models/fairseq_incremental_decoder.py
View file @
a09fe803
...
@@ -110,7 +110,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
...
@@ -110,7 +110,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def
set_beam_size
(
self
,
beam_size
):
def
set_beam_size
(
self
,
beam_size
):
"""Sets the beam size in the decoder and all children."""
"""Sets the beam size in the decoder and all children."""
def
apply_set_beam_size
(
module
):
if
getattr
(
self
,
'_beam_size'
,
-
1
)
!=
beam_size
:
if
module
!=
self
and
hasattr
(
module
,
'set_beam_size'
):
def
apply_set_beam_size
(
module
):
module
.
set_beam_size
(
beam_size
)
if
module
!=
self
and
hasattr
(
module
,
'set_beam_size'
):
self
.
apply
(
apply_set_beam_size
)
module
.
set_beam_size
(
beam_size
)
self
.
apply
(
apply_set_beam_size
)
self
.
_beam_size
=
beam_size
fairseq/models/fairseq_model.py
View file @
a09fe803
...
@@ -62,6 +62,11 @@ class FairseqModel(nn.Module):
...
@@ -62,6 +62,11 @@ class FairseqModel(nn.Module):
return
return
self
.
apply
(
apply_remove_weight_norm
)
self
.
apply
(
apply_remove_weight_norm
)
def
apply_make_generation_fast_
(
module
):
if
module
!=
self
and
hasattr
(
module
,
'make_generation_fast_'
):
module
.
make_generation_fast_
(
**
kwargs
)
self
.
apply
(
apply_make_generation_fast_
)
def
train
(
mode
):
def
train
(
mode
):
if
mode
:
if
mode
:
raise
RuntimeError
(
'cannot train after make_generation_fast'
)
raise
RuntimeError
(
'cannot train after make_generation_fast'
)
...
@@ -69,8 +74,3 @@ class FairseqModel(nn.Module):
...
@@ -69,8 +74,3 @@ class FairseqModel(nn.Module):
# this model should no longer be used for training
# this model should no longer be used for training
self
.
eval
()
self
.
eval
()
self
.
train
=
train
self
.
train
=
train
def
apply_make_generation_fast_
(
module
):
if
module
!=
self
and
hasattr
(
module
,
'make_generation_fast_'
):
module
.
make_generation_fast_
(
**
kwargs
)
self
.
apply
(
apply_make_generation_fast_
)
fairseq/models/fconv.py
View file @
a09fe803
...
@@ -145,7 +145,8 @@ class AttentionLayer(nn.Module):
...
@@ -145,7 +145,8 @@ class AttentionLayer(nn.Module):
def
make_generation_fast_
(
self
,
beamable_mm_beam_size
=
None
,
**
kwargs
):
def
make_generation_fast_
(
self
,
beamable_mm_beam_size
=
None
,
**
kwargs
):
"""Replace torch.bmm with BeamableMM."""
"""Replace torch.bmm with BeamableMM."""
if
beamable_mm_beam_size
is
not
None
:
if
beamable_mm_beam_size
is
not
None
:
self
.
bmm
=
BeamableMM
(
beamable_mm_beam_size
)
del
self
.
bmm
self
.
add_module
(
'bmm'
,
BeamableMM
(
beamable_mm_beam_size
))
class
FConvDecoder
(
FairseqIncrementalDecoder
):
class
FConvDecoder
(
FairseqIncrementalDecoder
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment