OpenDAS / Fairseq / Commits / 6edf81dd

Commit 6edf81dd (unverified), authored Jun 25, 2018 by Myle Ott, committed by GitHub on Jun 25, 2018
Parent: 74efc214

Remove more Variable() calls (#198)

Showing 7 changed files with 11 additions and 14 deletions (+11, -14)
fairseq/models/fconv.py             +1  -0
interactive.py                      +2  -3
tests/test_convtbc.py               +3  -3
tests/test_dictionary.py            +1  -1
tests/test_sequence_generator.py    +3  -4
tests/test_utils.py                 +0  -1
tests/utils.py                      +1  -2
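This commit continues the migration to PyTorch 0.4, in which Variable was merged into Tensor: wrapping a tensor in Variable() became a deprecated no-op, and autograd works on plain tensors. A minimal before/after sketch, assuming PyTorch >= 0.4:

    import torch

    # Pre-0.4 style (what this commit removes):
    #   from torch.autograd import Variable
    #   x = Variable(torch.randn(3, 4), requires_grad=True)

    # 0.4+ style: plain tensors carry requires_grad directly.
    x = torch.randn(3, 4, requires_grad=True)
    y = (x * 2).sum()
    y.backward()           # gradients flow without any Variable wrapper
    print(x.grad.shape)    # torch.Size([3, 4])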
fairseq/models/fconv.py (view file @ 6edf81dd)

@@ -633,6 +633,7 @@ def fconv_lm_dauphin_wikitext103(args):
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
     base_lm_architecture(args)


 @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
 def fconv_lm_dauphin_gbw(args):
     layers = '[(512, 5)]'
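The added line follows the usual fairseq architecture-function pattern: getattr with a default fills in a hyperparameter only when the user has not set it, so explicit command-line values still win. A small illustration (the Namespace values below are hypothetical):

    import argparse

    # No value set yet, so the preset's default wins:
    args = argparse.Namespace()
    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
    print(args.adaptive_softmax_cutoff)   # '10000,20000,200000'

    # If the user had passed a value, getattr keeps it:
    args2 = argparse.Namespace(adaptive_softmax_cutoff='5000,50000')
    args2.adaptive_softmax_cutoff = getattr(args2, 'adaptive_softmax_cutoff', '10000,20000,200000')
    print(args2.adaptive_softmax_cutoff)  # '5000,50000'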
interactive.py (view file @ 6edf81dd)

@@ -11,7 +11,6 @@ import numpy as np
 import sys

 import torch
-from torch.autograd import Variable

 from fairseq import data, options, tasks, tokenizer, utils
 from fairseq.sequence_generator import SequenceGenerator

@@ -131,8 +130,8 @@ def main(args):
         lengths = lengths.cuda()

     translations = translator.generate(
-        Variable(tokens),
-        Variable(lengths),
+        tokens,
+        lengths,
         maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
     )
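Aside from dropping the Variable() wrappers, the call site shows how the generation cap is computed: maxlen = int(max_len_a * src_len + max_len_b), a linear function of the source length. A worked example with assumed values (the actual numbers come from the --max-len-a/--max-len-b options):

    # Hypothetical values for illustration only.
    max_len_a, max_len_b = 1.5, 10
    src_len = 20  # tokens.size(1) for the batch of source sentences

    maxlen = int(max_len_a * src_len + max_len_b)
    print(maxlen)  # 40: generation stops after at most 40 target tokens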
tests/test_convtbc.py (view file @ 6edf81dd)

@@ -9,7 +9,6 @@ import torch
 import unittest

 from fairseq.modules import ConvTBC
 import torch.nn as nn
-from torch.autograd import Variable


 class TestConvTBC(unittest.TestCase):

@@ -23,8 +22,9 @@ class TestConvTBC(unittest.TestCase):
         conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
         conv_tbc.bias.data.copy_(conv1d.bias.data)

-        input_tbc = Variable(torch.randn(7, 2, 4), requires_grad=True)
-        input1d = Variable(input_tbc.data.transpose(0, 1).transpose(1, 2), requires_grad=True)
+        input_tbc = torch.randn(7, 2, 4, requires_grad=True)
+        input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
+        input1d.requires_grad = True

         output_tbc = conv_tbc(input_tbc)
         output1d = conv1d(input1d)
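The rewritten test shows the two idiomatic 0.4+ replacements for Variable(..., requires_grad=True): ask a factory function for gradients up front, or set .requires_grad on a leaf tensor (here .data detaches from the graph, so the transposed result is a leaf). A minimal sketch mirroring the committed test, assuming PyTorch >= 0.4:

    import torch

    # Option 1: request gradients at construction time.
    a = torch.randn(7, 2, 4, requires_grad=True)

    # Option 2: detach first, then flip the flag on the resulting leaf.
    b = a.data.transpose(0, 1).transpose(1, 2)  # .data detaches from the graph
    b.requires_grad = True                      # valid on a leaf tensor

    print(a.requires_grad, b.requires_grad)     # True True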
tests/test_dictionary.py (view file @ 6edf81dd)

@@ -11,7 +11,7 @@ import unittest
 import torch

 from fairseq.data import Dictionary
-from fairseq.tokenizer import Tokenizer, tokenize_line
+from fairseq.tokenizer import Tokenizer


 class TestDictionary(unittest.TestCase):
tests/test_sequence_generator.py (view file @ 6edf81dd)

@@ -9,7 +9,6 @@ import argparse
 import unittest

 import torch
-from torch.autograd import Variable

 from fairseq.sequence_generator import SequenceGenerator

@@ -29,11 +28,11 @@ class TestSequenceGenerator(unittest.TestCase):
         self.w2 = 5

         # construct source data
-        self.src_tokens = Variable(torch.LongTensor([
+        self.src_tokens = torch.LongTensor([
             [self.w1, self.w2, self.eos],
             [self.w1, self.w2, self.eos],
-        ]))
-        self.src_lengths = Variable(torch.LongTensor([2, 2]))
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])

         args = argparse.Namespace()
         unk = 0.
tests/test_utils.py (view file @ 6edf81dd)

@@ -8,7 +8,6 @@
 import unittest

 import torch
-from torch.autograd import Variable

 from fairseq import utils
tests/utils.py (view file @ 6edf81dd)

@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Variable

 from fairseq import utils
 from fairseq.data import Dictionary

@@ -156,7 +155,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         # random attention
         attn = torch.rand(bbsz, tgt_len, src_len)
-        return Variable(probs), Variable(attn)
+        return probs, attn

     def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly
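The same merge applies to return values: a decoder's forward can hand back plain tensors, and downstream code treats them exactly as it treated Variables. A toy sketch of the pattern, assuming PyTorch >= 0.4 (this ToyDecoder is illustrative only, not the fairseq API):

    import torch
    import torch.nn as nn

    class ToyDecoder(nn.Module):
        """Illustrative stand-in that returns (probs, attn) as plain tensors."""
        def forward(self, bbsz, tgt_len, src_len, vocab):
            probs = torch.rand(bbsz, tgt_len, vocab)
            attn = torch.rand(bbsz, tgt_len, src_len)
            return probs, attn  # no Variable() wrapping needed

    probs, attn = ToyDecoder()(2, 3, 4, 10)
    print(probs.shape, attn.shape)  # torch.Size([2, 3, 10]) torch.Size([2, 3, 4])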