"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "aed7499a8d81de78bb1692d7a0745d3890618b0e"
Commit e89329d6 authored by Myle Ott

Updates for latest PyTorch

parent ff68a9ef
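
For context: the diff below drops the pre-0.4 `torch.autograd.Variable` API, switches to the in-place (underscore-suffixed) `nn.init` functions, and uses `torch.no_grad()` directly instead of the `utils.maybe_no_grad()` / `volatile=True` shims. A minimal sketch of the replacement idioms, assuming PyTorch 0.4 or later; the variable names here are illustrative, not from the diff:

import torch
import torch.nn as nn

# Old (pre-0.4) style that this commit removes:
#   from torch.autograd import Variable
#   h0 = Variable(x.data.new(*size).zero_(), volatile=not train)
#   nn.init.normal(m.weight, 0, 0.1)
#
# New style used throughout the diff:
linear = nn.Linear(4, 4)
nn.init.normal_(linear.weight, 0, 0.1)   # init functions gained a trailing underscore (in-place)
nn.init.constant_(linear.bias, 0)

x = torch.zeros(2, 4)
with torch.no_grad():                    # replaces volatile=True / utils.maybe_no_grad()
    y = linear(x)                        # plain tensors carry autograd state; no Variable wrapper
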
@@ -26,7 +26,7 @@ class AdaptiveLoss(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

@@ -23,7 +23,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

@@ -24,7 +24,7 @@ class FairseqCriterion(_Loss):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

@@ -29,7 +29,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

@@ -565,23 +565,23 @@ def extend_conv_spec(convolutions):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal(m.weight, 0, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.normal_(m.weight, 0, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m

 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
     m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-    nn.init.normal(m.weight, 0, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.normal_(m.weight, 0, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m

 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
-    nn.init.normal(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m)

@@ -589,8 +589,8 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
     """Weight-normalized Conv1d layer optimized for decoding"""
     m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
-    nn.init.normal(m.weight, mean=0, std=std)
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=std)
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m, dim=2)

@@ -599,8 +599,8 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
     from fairseq.modules import ConvTBC
     m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
-    nn.init.normal(m.weight, mean=0, std=std)
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=std)
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m, dim=2)

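The convolutional model's init helpers switch from the deprecated `nn.init.normal` / `nn.init.constant` calls to the in-place `nn.init.normal_` / `nn.init.constant_` variants. A self-contained version of the Embedding helper above, runnable as a quick sanity check (the vocabulary size and dimension are made up):

import torch.nn as nn

def Embedding(num_embeddings, embedding_dim, padding_idx):
    # Same pattern as the fconv helper in the hunk: N(0, 0.1) weights, zeroed padding row.
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, 0, 0.1)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m

emb = Embedding(1000, 256, padding_idx=0)
assert emb.weight[0].abs().sum().item() == 0  # padding row stays zero after init
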
@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Variable
 import torch.nn as nn
 import torch.nn.functional as F

@@ -171,8 +170,8 @@ class LSTMEncoder(FairseqEncoder):
             state_size = 2 * self.num_layers, bsz, self.hidden_size
         else:
             state_size = self.num_layers, bsz, self.hidden_size
-        h0 = Variable(x.data.new(*state_size).zero_())
-        c0 = Variable(x.data.new(*state_size).zero_())
+        h0 = x.data.new(*state_size).zero_()
+        c0 = x.data.new(*state_size).zero_()
         packed_outs, (final_hiddens, final_cells) = self.lstm(
             packed_x,
             (h0, c0),

@@ -306,9 +305,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             num_layers = len(self.layers)
             prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
             prev_cells = [encoder_cells[i] for i in range(num_layers)]
-            input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())
+            input_feed = x.data.new(bsz, self.encoder_output_units).zero_()

-        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
+        attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
         outs = []
         for j in range(seqlen):
             # input feeding: concatenate context vector from previous time step

@@ -390,8 +389,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.uniform(m.weight, -0.1, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.uniform_(m.weight, -0.1, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m

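With the Tensor/Variable merge, a tensor returned by `x.data.new(...).zero_()` already participates in autograd, so the `Variable(...)` wrapper around the LSTM's initial states and attention buffer simply disappears. A small sketch; `new_zeros` is an equivalent, slightly more idiomatic spelling in 0.4+, not something this commit uses:

import torch

x = torch.randn(3, 5, requires_grad=True)

# What the diff does: drop the Variable() wrapper, keep the .data.new(...).zero_() call.
h0 = x.data.new(2, 4).zero_()

# Equivalent, more idiomatic alternative (not used by this commit):
h0_alt = x.new_zeros(2, 4)

assert torch.equal(h0, h0_alt) and h0.dtype == x.dtype
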
@@ -181,7 +181,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal(self.embed_out, mean=0, std=embed_dim ** -0.5)
+            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)

     def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions

@@ -363,7 +363,7 @@ class TransformerDecoderLayer(nn.Module):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
     return m

@@ -374,16 +374,16 @@ def LayerNorm(embedding_dim):
 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
-    nn.init.xavier_uniform(m.weight)
-    nn.init.constant(m.bias, 0.)
+    nn.init.xavier_uniform_(m.weight)
+    nn.init.constant_(m.bias, 0.)
     return m

 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-        nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
-        nn.init.constant(m.weight[padding_idx], 0)
+        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+        nn.init.constant_(m.weight[padding_idx], 0)
     else:
         m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
     return m

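The Transformer helpers keep their initialization scheme and only adopt the in-place calls: embeddings (and the untied output projection `embed_out`) get N(0, embed_dim ** -0.5), i.e. std = 1/sqrt(d), while Linear layers get Xavier-uniform weights and zero bias. A quick illustration with made-up sizes:

import torch
import torch.nn as nn

embed_dim = 512
embed_tokens = nn.Embedding(32000, embed_dim, padding_idx=1)
nn.init.normal_(embed_tokens.weight, mean=0, std=embed_dim ** -0.5)  # std = 1/sqrt(512) ~= 0.044
nn.init.constant_(embed_tokens.weight[1], 0)

fc = nn.Linear(embed_dim, embed_dim)
nn.init.xavier_uniform_(fc.weight)
nn.init.constant_(fc.bias, 0.)

print(round(embed_tokens.weight.std().item(), 3))  # close to 0.044
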
@@ -44,7 +44,7 @@ class AdaptiveSoftmax(nn.Module):
         def init_weights(m):
             if hasattr(m, 'weight'):
-                nn.init.xavier_uniform(m.weight)
+                nn.init.xavier_uniform_(m.weight)

         self.apply(init_weights)

@@ -11,7 +11,6 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable

 from fairseq.modules.scalar_bias import scalar_bias

@@ -110,14 +109,14 @@ class SingleHeadAttention(nn.Module):
         if mask_future_timesteps:
             assert query.size() == key.size(), \
                 'mask_future_timesteps only applies to self-attention'
-            attn_weights *= Variable(torch.tril(
+            attn_weights *= torch.tril(
                 attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                 diagonal=-1,
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
-            attn_weights += Variable(torch.triu(
+            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
+            attn_weights += torch.triu(
                 attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                 diagonal=0
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
+            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
         tgt_size = tgt_len
         if use_scalar_bias:
             attn_weights = scalar_bias(attn_weights, 2)

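In the attention hunk, the future-timestep mask is now built from plain tensors: a lower-triangular multiplicative mask keeps the allowed positions and an upper-triangular additive mask pushes the rest to -inf before the softmax. A stripped-down sketch of the same idea; it uses the standard causal diagonals (the fairseq code shifts both by one because a scalar-bias column may be appended afterwards) and omits the downsampling stride:

import math
import torch

tgt_len, bsz = 4, 1
attn_weights = torch.randn(bsz, tgt_len, tgt_len)

# Keep the diagonal and below, add -inf strictly above it.
keep = torch.tril(attn_weights.new_ones(tgt_len, tgt_len)).unsqueeze(0)
block = torch.triu(attn_weights.new_full((tgt_len, tgt_len), -math.inf), diagonal=1).unsqueeze(0)

probs = torch.softmax(attn_weights * keep + block, dim=-1)
assert torch.allclose(probs.sum(-1), torch.ones(bsz, tgt_len))  # each row is a valid distribution
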
@@ -13,7 +13,6 @@ class GradMultiply(torch.autograd.Function):
     def forward(ctx, x, scale):
         ctx.scale = scale
         res = x.new(x)
-        ctx.mark_shared_storage((x, res))
         return res

     @staticmethod

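`ctx.mark_shared_storage` was an internal hint that newer PyTorch no longer needs (or provides), so it is simply deleted. Below is a minimal gradient-scaling Function in the same spirit as GradMultiply; the backward pass here is an assumption based on the class name, since the hunk only shows forward():

import torch

class GradScale(torch.autograd.Function):
    """Identity in the forward pass, scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()  # the old code used res = x.new(x); clone() copies the data as well

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # assumed behaviour; None for the non-tensor `scale` arg

x = torch.ones(3, requires_grad=True)
GradScale.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])
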
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.autograd import Variable
 import torch.nn as nn

 from fairseq import utils

@@ -29,7 +28,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
             positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
-        return super().forward(Variable(positions))
+        return super().forward(positions)

     def max_positions(self):
         """Maximum number of supported positions."""

@@ -59,8 +59,8 @@ class LinearizedConvolution(ConvTBC):
                 input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
                 # append next input
                 input_buffer[:, -1, :] = input[:, -1, :]
-                input = utils.volatile_variable(input_buffer)
-            with utils.maybe_no_grad():
+                input = input_buffer
+            with torch.no_grad():
                 output = F.linear(input.view(bsz, -1), weight, self.bias)
             return output.view(bsz, 1, -1)

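`utils.volatile_variable()` and `utils.maybe_no_grad()` existed to paper over the old `volatile=True` flag; with recent PyTorch the incremental decoding path just enters `torch.no_grad()`. A two-line check that nothing inside the block is tracked by autograd:

import torch
import torch.nn.functional as F

weight = torch.randn(8, 16, requires_grad=True)
x = torch.randn(4, 16)

with torch.no_grad():
    y = F.linear(x, weight)

assert not y.requires_grad  # nothing inside the block is recorded for autograd
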
@@ -38,11 +38,11 @@ class MultiheadAttention(nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight)
-        nn.init.xavier_uniform(self.out_proj.weight)
+        nn.init.xavier_uniform_(self.in_proj_weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            nn.init.constant(self.in_proj_bias, 0.)
-            nn.init.constant(self.out_proj.bias, 0.)
+            nn.init.constant_(self.in_proj_bias, 0.)
+            nn.init.constant_(self.out_proj.bias, 0.)

     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,

@@ -8,7 +8,6 @@
 import math

 import torch
-from torch.autograd import Variable
 import torch.nn as nn

 from fairseq import utils

@@ -64,14 +63,13 @@ class SinusoidalPositionalEmbedding(nn.Module):
                 self.padding_idx,
             ).type_as(self.weights)
         self.weights = self.weights.type_as(self._float_tensor)
-        weights = Variable(self.weights)

         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            return weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
+            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

-        positions = Variable(utils.make_positions(input.data, self.padding_idx, self.left_pad))
-        return weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
+        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)

     def max_positions(self):
         """Maximum number of supported positions."""

@@ -66,14 +66,14 @@ class SequenceGenerator(object):
             maxlen_b = self.maxlen
         for sample in data_itr:
-            s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            s = utils.move_to_cuda(sample) if cuda else sample
             if 'net_input' not in s:
                 continue
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 hypos = self.generate(
                     input['src_tokens'],
                     input['src_lengths'],

@@ -91,7 +91,7 @@ class SequenceGenerator(object):
     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with utils.maybe_no_grad():
+        with torch.no_grad():
             return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):

@@ -492,14 +492,11 @@ class SequenceGenerator(object):
         return finalized

     def _decode(self, tokens, encoder_outs, incremental_states):
-        # wrap in Variable
-        tokens = utils.volatile_variable(tokens)
-
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 if incremental_states[model] is not None:
                     decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
                 else:

@@ -23,7 +23,7 @@ class SequenceScorer(object):
     def score_batched_itr(self, data_itr, cuda=False, timer=None):
         """Iterate over a batched dataset and yield scored translations."""
         for sample in data_itr:
-            s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            s = utils.move_to_cuda(sample) if cuda else sample
             if timer is not None:
                 timer.start()
             pos_scores, attn = self.score(s)

@@ -59,7 +59,7 @@ class SequenceScorer(object):
         avg_probs = None
         avg_attn = None
         for model in self.models:
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 model.eval()
                 decoder_out = model.forward(**net_input)
                 attn = decoder_out[1]

@@ -10,6 +10,7 @@ Train a network across multiple GPUs.
 """

 from collections import defaultdict, OrderedDict
+import contextlib
 from itertools import chain

 import torch

@@ -112,7 +113,7 @@ class Trainer(object):
             torch.cuda.manual_seed(seed)

         # forward and backward pass
-        sample = self._prepare_sample(sample, volatile=False)
+        sample = self._prepare_sample(sample)
         loss, sample_size, logging_output, oom_fwd = self._forward(sample)
         oom_bwd = self._backward(loss)

@@ -191,7 +192,7 @@ class Trainer(object):
         oom = 0
         if sample is not None:
             try:
-                with utils.maybe_no_grad(eval):
+                with torch.no_grad() if eval else contextlib.ExitStack():
                     # calculate loss and sample size
                     loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                     logging_output.update(logging_output_)

@@ -276,10 +277,8 @@ class Trainer(object):
     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
-        sample = self._prepare_sample(sample, volatile=True)
-
         # forward pass
+        sample = self._prepare_sample(sample)
         _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
         assert not oom_fwd, 'Ran out of memory during validation'

@@ -344,7 +343,7 @@ class Trainer(object):
         """Get the number of parameters updates."""
         return self._num_updates

-    def _prepare_sample(self, sample, volatile):
+    def _prepare_sample(self, sample):
         if sample is None or len(sample) == 0:
             return None
-        return utils.make_variable(sample, volatile=volatile, cuda=True)
+        return utils.move_to_cuda(sample)

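The Trainer now chooses its gradient context at runtime: `torch.no_grad()` when evaluating, and `contextlib.ExitStack()` as a no-op context manager during training, so the training path still builds a graph. A small sketch of the same construct with a stand-in model (the real code forwards through a fairseq task and criterion):

import contextlib
import torch
import torch.nn as nn

model = nn.Linear(16, 4)   # stand-in model for illustration
x = torch.randn(2, 16)

def forward(eval=False):
    # Same construct as the Trainer hunk above.
    with torch.no_grad() if eval else contextlib.ExitStack():
        return model(x).sum()

assert forward(eval=False).requires_grad       # training path keeps the graph
assert not forward(eval=True).requires_grad    # evaluation path does not
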
@@ -6,14 +6,12 @@
 # can be found in the PATENTS file in the same directory.

 from collections import defaultdict, OrderedDict
-import contextlib
 import logging
 import os
 import re
 import torch
 import traceback

-from torch.autograd import Variable
 from torch.serialization import default_restore_location

@@ -169,46 +167,24 @@ def _override_model_args(args, model_arg_overrides):
     return args

-def maybe_no_grad(condition=True):
-    if hasattr(torch, 'no_grad') and condition:
-        return torch.no_grad()
-    # no-op context manager
-    return contextlib.ExitStack()
-
-
-def volatile_variable(*args, **kwargs):
-    if hasattr(torch, 'no_grad'):
-        # volatile has been deprecated, use the no_grad context manager instead
-        return Variable(*args, **kwargs)
-    else:
-        return Variable(*args, **kwargs, volatile=True)
-
-
-def make_variable(sample, volatile=False, cuda=False):
-    """Wrap input tensors in Variable class."""
+def move_to_cuda(sample):
     if len(sample) == 0:
         return {}

-    def _make_variable(maybe_tensor):
+    def _move_to_cuda(maybe_tensor):
         if torch.is_tensor(maybe_tensor):
-            if cuda and torch.cuda.is_available():
-                maybe_tensor = maybe_tensor.cuda()
-            if volatile:
-                return volatile_variable(maybe_tensor)
-            else:
-                return Variable(maybe_tensor)
+            return maybe_tensor.cuda()
         elif isinstance(maybe_tensor, dict):
             return {
-                key: _make_variable(value)
+                key: _move_to_cuda(value)
                 for key, value in maybe_tensor.items()
             }
         elif isinstance(maybe_tensor, list):
-            return [_make_variable(x) for x in maybe_tensor]
+            return [_move_to_cuda(x) for x in maybe_tensor]
         else:
             return maybe_tensor

-    return _make_variable(sample)
+    return _move_to_cuda(sample)

 INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)

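The rewritten helper recursively walks dicts and lists and calls `.cuda()` on every tensor; unlike the old `make_variable` it no longer wraps or checks anything, which is why callers now guard it with `if cuda else sample`. A usage sketch with a made-up sample dict, guarded so it also runs on CPU-only machines:

import torch
from fairseq import utils

sample = {
    'id': torch.arange(2),
    'net_input': {
        'src_tokens': torch.randint(0, 100, (2, 7)),
        'src_lengths': torch.tensor([7, 5]),
    },
}

cuda = torch.cuda.is_available()
s = utils.move_to_cuda(sample) if cuda else sample   # same guard as SequenceGenerator above
assert s['net_input']['src_tokens'].is_cuda == cuda
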
@@ -77,16 +77,6 @@ class TestUtils(unittest.TestCase):
             utils.make_positions(right_pad_input, pad, left_pad=False),
         )

-    def test_make_variable(self):
-        t = [{'k': torch.rand(5, 5)}]
-        v = utils.make_variable(t)[0]['k']
-        self.assertTrue(isinstance(v, Variable))
-        self.assertFalse(v.data.is_cuda)
-
-        v = utils.make_variable(t, cuda=True)[0]['k']
-        self.assertEqual(v.data.is_cuda, torch.cuda.is_available())
-
     def assertAlmostEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
         self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)