OpenDAS / Fairseq · Commit 5eddda8b

Save dictionary in model base classes

Authored Jan 09, 2018 by Myle Ott
Parent: 08a74a32
Showing 5 changed files with 10 additions and 12 deletions (+10 / -12):
fairseq/models/fairseq_decoder.py               +2 -1
fairseq/models/fairseq_encoder.py               +2 -1
fairseq/models/fairseq_incremental_decoder.py   +2 -2
fairseq/models/fconv.py                         +2 -4
fairseq/models/lstm.py                          +2 -4
fairseq/models/fairseq_decoder.py

@@ -13,8 +13,9 @@ import torch.nn.functional as F
 class FairseqDecoder(nn.Module):
     """Base class for decoders."""

-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary

     def get_normalized_probs(self, net_output, log_probs):
         """Get normalized probabilities (or log probs) from a net's output."""
fairseq/models/fairseq_encoder.py

@@ -12,8 +12,9 @@ import torch.nn as nn
 class FairseqEncoder(nn.Module):
     """Base class for encoders."""

-    def __init__(self):
+    def __init__(self, dictionary):
         super().__init__()
+        self.dictionary = dictionary

     def max_positions(self):
         """Maximum input length supported by the encoder."""
fairseq/models/fairseq_incremental_decoder.py

@@ -12,8 +12,8 @@ from . import FairseqDecoder
 class FairseqIncrementalDecoder(FairseqDecoder):
     """Base class for incremental decoders."""

-    def __init__(self):
-        super().__init__()
+    def __init__(self, dictionary):
+        super().__init__(dictionary)
         self._is_incremental_eval = False
         self._incremental_state = {}
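With the dictionary saved by the base classes, custom components only need to forward it to super().__init__(). A minimal sketch of the new pattern follows; the MyDecoder class and its arguments are hypothetical, not part of this commit, and it assumes a fairseq Dictionary-like object exposing __len__ and pad():

import torch.nn as nn

from fairseq.models import FairseqDecoder  # base class changed in this commit


class MyDecoder(FairseqDecoder):
    """Hypothetical subclass illustrating the new constructor contract."""

    def __init__(self, dictionary, embed_dim=512):
        # Pass the dictionary up: FairseqDecoder.__init__ now stores it as
        # self.dictionary, so subclasses no longer assign it themselves.
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim,
                                         padding_idx=dictionary.pad())

The concrete models below (fconv.py, lstm.py) are updated in exactly this way: the redundant self.dictionary assignment is dropped in favor of passing the dictionary to the base class.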
fairseq/models/fconv.py

@@ -28,8 +28,7 @@ class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""
     def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                  convolutions=((512, 3),) * 20, dropout=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None

@@ -137,9 +136,8 @@ class FConvDecoder(FairseqIncrementalDecoder):
     def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
                  attention=True, dropout=0.1):
-        super().__init__()
+        super().__init__(dictionary)
         self.register_buffer('version', torch.Tensor([2]))
-        self.dictionary = dictionary
         self.dropout = dropout
         in_channels = convolutions[0][0]
fairseq/models/lstm.py

@@ -23,8 +23,7 @@ class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
     def __init__(self, dictionary, embed_dim=512, num_layers=1,
                  dropout_in=0.1, dropout_out=0.1):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out

@@ -108,8 +107,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
     def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
                  out_embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1, attention=True):
-        super().__init__()
-        self.dictionary = dictionary
+        super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
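Because every encoder and decoder now keeps its dictionary on the shared base classes, downstream code can read vocabularies uniformly across architectures. A rough usage sketch, assuming a composed model object with .encoder and .decoder attributes (not shown in this commit):

def vocab_sizes(model):
    # Works for FConv, LSTM, or any other FairseqEncoder/FairseqDecoder pair,
    # since self.dictionary is now saved by the base class constructors.
    return len(model.encoder.dictionary), len(model.decoder.dictionary)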