"docs/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "728ac510b08f7fa5c385f2304e4e2cf61305024c"
Unverified commit 6edf81dd, authored by Myle Ott, committed by GitHub

Remove more Variable() calls (#198)

parent 74efc214
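For context (not part of the diff): from PyTorch 0.4 onward, Tensor and Variable are merged, so wrapping a tensor in Variable() is redundant and the wrapper can simply be dropped. A minimal sketch of the new-style usage, assuming PyTorch >= 0.4:

import torch

# PyTorch >= 0.4: tensors track gradients themselves, no Variable() needed.
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
y.backward()
print(x.grad)  # gradients live directly on the tensor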
@@ -633,6 +633,7 @@ def fconv_lm_dauphin_wikitext103(args):
     args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', '10000,20000,200000')
     base_lm_architecture(args)
 
 
 @register_model_architecture('fconv_lm', 'fconv_lm_dauphin_gbw')
 def fconv_lm_dauphin_gbw(args):
     layers = '[(512, 5)]'
@@ -11,7 +11,6 @@ import numpy as np
 import sys
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import data, options, tasks, tokenizer, utils
 from fairseq.sequence_generator import SequenceGenerator
@@ -131,8 +130,8 @@ def main(args):
             lengths = lengths.cuda()
 
         translations = translator.generate(
-            Variable(tokens),
-            Variable(lengths),
+            tokens,
+            lengths,
             maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
         )
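The hunk above shows that translator.generate() is now handed plain tensors. A rough, hypothetical sketch of the calling pattern (generate() below is a stand-in, not fairseq's SequenceGenerator):

import torch

# Hypothetical stand-in for translator.generate(); the point is only that
# plain (optionally CUDA) tensors are accepted where Variables once were.
def generate(tokens, lengths, maxlen=200):
    batch_size, src_len = tokens.size(0), tokens.size(1)
    return torch.zeros(batch_size, min(maxlen, 2 * src_len), dtype=torch.long)

tokens = torch.LongTensor([[4, 5, 2], [4, 5, 2]])
lengths = torch.LongTensor([3, 3])
if torch.cuda.is_available():
    tokens, lengths = tokens.cuda(), lengths.cuda()

out = generate(tokens, lengths, maxlen=int(1.5 * tokens.size(1) + 5))
print(out.shape)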
@@ -9,7 +9,6 @@ import torch
 import unittest
 
 from fairseq.modules import ConvTBC
 import torch.nn as nn
-from torch.autograd import Variable
 
 
 class TestConvTBC(unittest.TestCase):
@@ -23,8 +22,9 @@ class TestConvTBC(unittest.TestCase):
         conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
         conv_tbc.bias.data.copy_(conv1d.bias.data)
 
-        input_tbc = Variable(torch.randn(7, 2, 4), requires_grad=True)
-        input1d = Variable(input_tbc.data.transpose(0, 1).transpose(1, 2), requires_grad=True)
+        input_tbc = torch.randn(7, 2, 4, requires_grad=True)
+        input1d = input_tbc.data.transpose(0, 1).transpose(1, 2)
+        input1d.requires_grad = True
 
         output_tbc = conv_tbc(input_tbc)
         output1d = conv1d(input1d)
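The replacement lines above rely on two post-0.4 idioms: factory functions accept requires_grad= directly, and a tensor derived from .data has no autograd history, so its requires_grad flag can be switched on afterwards. A small sketch of the same pattern, assuming PyTorch >= 0.4:

import torch

# Factory kwarg replaces Variable(torch.randn(...), requires_grad=True)
a = torch.randn(7, 2, 4, requires_grad=True)

# a.data is detached from autograd, so the transposed result has no history
# (it is a leaf) and its requires_grad flag may be flipped afterwards.
b = a.data.transpose(0, 1).transpose(1, 2)
b.requires_grad = True  # equivalently: b.requires_grad_()

b.sum().backward()
print(b.grad.shape)  # torch.Size([2, 4, 7])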
@@ -11,7 +11,7 @@ import unittest
 import torch
 
 from fairseq.data import Dictionary
-from fairseq.tokenizer import Tokenizer, tokenize_line
+from fairseq.tokenizer import Tokenizer
 
 
 class TestDictionary(unittest.TestCase):
@@ -9,7 +9,6 @@ import argparse
 import unittest
 
 import torch
-from torch.autograd import Variable
 
 from fairseq.sequence_generator import SequenceGenerator
@@ -29,11 +28,11 @@ class TestSequenceGenerator(unittest.TestCase):
         self.w2 = 5
 
         # construct source data
-        self.src_tokens = Variable(torch.LongTensor([
-            [self.w1, self.w2, self.eos],
-            [self.w1, self.w2, self.eos],
-        ]))
-        self.src_lengths = Variable(torch.LongTensor([2, 2]))
+        self.src_tokens = torch.LongTensor([
+            [self.w1, self.w2, self.eos],
+            [self.w1, self.w2, self.eos],
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])
 
         args = argparse.Namespace()
         unk = 0.
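Dropping the wrapper here loses nothing: integer (LongTensor) inputs cannot require gradients in the first place, so Variable() around them was purely cosmetic. A minimal illustration:

import torch

src_tokens = torch.LongTensor([
    [4, 5, 2],
    [4, 5, 2],
])
src_lengths = torch.LongTensor([2, 2])

# Integer tensors never participate in autograd, so nothing changes behaviourally.
print(src_tokens.requires_grad)                # False
print(src_tokens.dtype, src_lengths.tolist())  # torch.int64 [2, 2]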
@@ -8,7 +8,6 @@
 import unittest
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import utils
@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.
 
 import torch
-from torch.autograd import Variable
 
 from fairseq import utils
 from fairseq.data import Dictionary
@@ -156,7 +155,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         # random attention
         attn = torch.rand(bbsz, tgt_len, src_len)
-        return Variable(probs), Variable(attn)
+        return probs, attn
 
     def get_normalized_probs(self, net_output, log_probs, _):
         # the decoder returns probabilities directly