Unverified commit 6edf81dd, authored by Myle Ott and committed by GitHub

Remove more Variable() calls (#198)

parent 74efc214
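
Background for the change: PyTorch 0.4 merged torch.autograd.Variable into torch.Tensor, so wrapping tensors in Variable() is now a no-op and requires_grad can be passed straight to factory functions. A minimal before/after sketch of the pattern applied throughout this commit (illustrative only, not code from the repository):

import torch
# from torch.autograd import Variable   # pre-0.4 import, no longer needed

# pre-0.4: x = Variable(torch.randn(3, 4), requires_grad=True)
x = torch.randn(3, 4, requires_grad=True)   # 0.4+: tensors carry requires_grad themselves
loss = (x * 2).sum()
loss.backward()                             # autograd works directly on plain tensors
print(x.grad.shape)                         # torch.Size([3, 4])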
@@ -633,6 +633,7 @@ def fconv_lm_dauphin_wikitext103(args):
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
     base_lm_architecture(args)
 
+
 @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
 def fconv_lm_dauphin_gbw(args):
     layers = '[(512, 5)]'
@@ -11,7 +11,6 @@ import numpy as np
 import sys
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import data, options, tasks, tokenizer, utils
 from fairseq.sequence_generator import SequenceGenerator
@@ -131,8 +130,8 @@ def main(args):
             lengths = lengths.cuda()
 
         translations = translator.generate(
-            Variable(tokens),
-            Variable(lengths),
+            tokens,
+            lengths,
             maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
         )
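
A related point for inference code like the generate() call above: with Variable removed, the PyTorch 0.4+ way to skip gradient tracking is a torch.no_grad() block rather than volatile=True Variables. A small sketch, separate from this diff:

import torch

x = torch.randn(2, 3, requires_grad=True)
with torch.no_grad():             # replaces the old Variable(..., volatile=True) idiom
    y = x * 2
assert y.requires_grad is False   # no autograd history is recorded inside the block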
@@ -9,7 +9,6 @@ import torch
 import unittest
 
 from fairseq.modules import ConvTBC
 import torch.nn as nn
-from torch.autograd import Variable
 
 class TestConvTBC(unittest.TestCase):
@@ -23,8 +22,9 @@ class TestConvTBC(unittest.TestCase):
         conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
         conv_tbc.bias.data.copy_(conv1d.bias.data)
 
-        input_tbc = Variable(torch.randn(7, 2, 4), requires_grad=True)
-        input1d = Variable(input_tbc.data.transpose(0, 1).transpose(1, 2), requires_grad=True)
+        input_tbc = torch.randn(7, 2, 4, requires_grad=True)
+        input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
+        input1d.requires_grad = True
 
         output_tbc = conv_tbc(input_tbc)
         output1d = conv1d(input1d)
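
The ConvTBC test above builds its second input from .data: .data returns a detached tensor that shares storage with the original but carries no autograd history, and flipping requires_grad on the result makes it an independent leaf. A rough sketch of that idiom, using the same shapes as the test (tensor.detach() is the commonly recommended equivalent of .data):

import torch

x = torch.randn(7, 2, 4, requires_grad=True)    # leaf tensor, TxBxC layout
y = x.data.transpose(0, 1).transpose(1, 2)      # detached view of the same storage, BxCxT
y.requires_grad = True                          # y becomes a separate autograd leaf
assert y.shape == (2, 4, 7)
y.sum().backward()                              # gradients accumulate on y only
assert y.grad is not None and x.grad is None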
@@ -11,7 +11,7 @@ import unittest
 import torch
 
 from fairseq.data import Dictionary
-from fairseq.tokenizer import Tokenizer, tokenize_line
+from fairseq.tokenizer import Tokenizer
 
 
 class TestDictionary(unittest.TestCase):
@@ -9,7 +9,6 @@ import argparse
 import unittest
 
 import torch
-from torch.autograd import Variable
 
 from fairseq.sequence_generator import SequenceGenerator
@@ -29,11 +28,11 @@ class TestSequenceGenerator(unittest.TestCase):
         self.w2 = 5
 
         # construct source data
-        self.src_tokens = Variable(torch.LongTensor([
+        self.src_tokens = torch.LongTensor([
             [self.w1, self.w2, self.eos],
             [self.w1, self.w2, self.eos],
-        ]))
-        self.src_lengths = Variable(torch.LongTensor([2, 2]))
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])
 
         args = argparse.Namespace()
         unk = 0.
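
After the 0.4 merge, the Variable() wrappers around the torch.LongTensor source data above are unnecessary: plain tensors are accepted everywhere, and integer tensors never track gradients anyway (requires_grad is only supported for floating point and complex dtypes). A tiny sketch with made-up token ids:

import torch

src_tokens = torch.LongTensor([[4, 5, 2], [4, 5, 2]])   # batch of token id sequences
src_lengths = torch.LongTensor([2, 2])
assert not src_tokens.requires_grad                     # integer tensors cannot require grad
assert src_tokens.size() == (2, 3) and src_lengths.tolist() == [2, 2]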
@@ -8,7 +8,6 @@
 import unittest
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import utils
@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import utils
 from fairseq.data import Dictionary
@@ -156,7 +155,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         # random attention
         attn = torch.rand(bbsz, tgt_len, src_len)
 
-        return Variable(probs), Variable(attn)
+        return probs, attn
 
     def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly