Commit e89329d6, authored Jun 12, 2018 by Myle Ott
Updates for latest PyTorch
Parent: ff68a9ef
Showing 19 changed files with 59 additions and 103 deletions (+59 / -103):
fairseq/criterions/adaptive_loss.py (+1, -1)
fairseq/criterions/cross_entropy.py (+1, -1)
fairseq/criterions/fairseq_criterion.py (+1, -1)
fairseq/criterions/label_smoothed_cross_entropy.py (+1, -1)
fairseq/models/fconv.py (+10, -10)
fairseq/models/lstm.py (+6, -7)
fairseq/models/transformer.py (+6, -6)
fairseq/modules/adaptive_softmax.py (+1, -1)
fairseq/modules/downsampled_multihead_attention.py (+4, -5)
fairseq/modules/grad_multiply.py (+0, -1)
fairseq/modules/learned_positional_embedding.py (+1, -2)
fairseq/modules/linearized_convolution.py (+2, -2)
fairseq/modules/multihead_attention.py (+4, -4)
fairseq/modules/sinusoidal_positional_embedding.py (+3, -5)
fairseq/sequence_generator.py (+4, -7)
fairseq/sequence_scorer.py (+2, -2)
fairseq/trainer.py (+6, -7)
fairseq/utils.py (+6, -30)
tests/test_utils.py (+0, -10)
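
Every diff below applies one of a handful of PyTorch 0.4 migration patterns: the Tensor/Variable merge, the renamed in-place nn.init.*_ functions, and torch.no_grad() replacing volatile=True. A minimal summary sketch of the old and new idioms, assuming PyTorch >= 0.4 (the tensor x and layer m are illustrative, not fairseq code):

import torch
import torch.nn as nn

x = torch.randn(2, 3)
m = nn.Linear(3, 3)

# 1) Tensors and Variables are merged, so Variable(...) wrapping is unnecessary.
h = x.new(2, 3).zero_()            # was: Variable(x.data.new(2, 3).zero_())

# 2) Out-of-place init functions are deprecated in favor of in-place ones.
nn.init.xavier_uniform_(m.weight)  # was: nn.init.xavier_uniform(m.weight)

# 3) volatile=True is gone; wrap inference in the no_grad context manager.
with torch.no_grad():              # was: Variable(..., volatile=True)
    y = m(x)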

fairseq/criterions/adaptive_loss.py
@@ -26,7 +26,7 @@ class AdaptiveLoss(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

fairseq/criterions/cross_entropy.py
@@ -23,7 +23,7 @@ class CrossEntropyCriterion(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

fairseq/criterions/fairseq_criterion.py
@@ -24,7 +24,7 @@ class FairseqCriterion(_Loss):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

fairseq/criterions/label_smoothed_cross_entropy.py
@@ -29,7 +29,7 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
-        1) the loss, as a Variable
+        1) the loss
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """

fairseq/models/fconv.py
@@ -565,23 +565,23 @@ def extend_conv_spec(convolutions):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal(m.weight, 0, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.normal_(m.weight, 0, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m


 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad):
     m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-    nn.init.normal(m.weight, 0, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.normal_(m.weight, 0, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m


 def Linear(in_features, out_features, dropout=0):
     """Weight-normalized Linear layer (input: N x T x C)"""
     m = nn.Linear(in_features, out_features)
-    nn.init.normal(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=math.sqrt((1 - dropout) / in_features))
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m)
@@ -589,8 +589,8 @@ def LinearizedConv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs
     """Weight-normalized Conv1d layer optimized for decoding"""
     m = LinearizedConvolution(in_channels, out_channels, kernel_size, **kwargs)
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
-    nn.init.normal(m.weight, mean=0, std=std)
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=std)
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m, dim=2)
@@ -599,8 +599,8 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
     from fairseq.modules import ConvTBC
     m = ConvTBC(in_channels, out_channels, kernel_size, **kwargs)
     std = math.sqrt((4 * (1.0 - dropout)) / (m.kernel_size[0] * in_channels))
-    nn.init.normal(m.weight, mean=0, std=std)
-    nn.init.constant(m.bias, 0)
+    nn.init.normal_(m.weight, mean=0, std=std)
+    nn.init.constant_(m.bias, 0)
     return nn.utils.weight_norm(m, dim=2)
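
The init calls above change because PyTorch 0.4 deprecates the out-of-place nn.init functions in favor of in-place variants with a trailing underscore. A minimal sketch of the pattern, assuming PyTorch >= 0.4 (the layer sizes are illustrative):

import torch.nn as nn

m = nn.Embedding(100, 16, padding_idx=0)
# Deprecated (emits a warning): nn.init.normal(m.weight, 0, 0.1)
nn.init.normal_(m.weight, 0, 0.1)   # in-place, trailing underscore
nn.init.constant_(m.weight[0], 0)   # zero out the padding row in place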

fairseq/models/lstm.py
@@ -6,7 +6,6 @@
 # can be found in the PATENTS file in the same directory.

 import torch
-from torch.autograd import Variable
 import torch.nn as nn
 import torch.nn.functional as F
@@ -171,8 +170,8 @@ class LSTMEncoder(FairseqEncoder):
             state_size = 2 * self.num_layers, bsz, self.hidden_size
         else:
             state_size = self.num_layers, bsz, self.hidden_size
-        h0 = Variable(x.data.new(*state_size).zero_())
-        c0 = Variable(x.data.new(*state_size).zero_())
+        h0 = x.data.new(*state_size).zero_()
+        c0 = x.data.new(*state_size).zero_()
         packed_outs, (final_hiddens, final_cells) = self.lstm(
             packed_x,
             (h0, c0),
@@ -306,9 +305,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             num_layers = len(self.layers)
             prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
             prev_cells = [encoder_cells[i] for i in range(num_layers)]
-            input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())
+            input_feed = x.data.new(bsz, self.encoder_output_units).zero_()

-        attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
+        attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
         outs = []
         for j in range(seqlen):
             # input feeding: concatenate context vector from previous time step
@@ -390,8 +389,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):

 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.uniform(m.weight, -0.1, 0.1)
-    nn.init.constant(m.weight[padding_idx], 0)
+    nn.init.uniform_(m.weight, -0.1, 0.1)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m
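
The Variable(...) wrappers around the fresh hidden states disappear because, after the Tensor/Variable merge, a tensor created with .new(...) can be passed to nn.LSTM directly. A minimal sketch, assuming PyTorch >= 0.4 (the shapes are illustrative, not the model's real sizes):

import torch

x = torch.randn(7, 3, 32)               # stand-in for the embedded source tokens
state_size = (1, 3, 32)                 # illustrative (num_layers, bsz, hidden_size)

# PyTorch < 0.4: h0 = Variable(x.data.new(*state_size).zero_())
h0 = x.data.new(*state_size).zero_()    # a plain tensor now works everywhere
c0 = x.new_zeros(state_size)            # equivalent, slightly more idiomatic spelling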

fairseq/models/transformer.py
@@ -181,7 +181,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
-            nn.init.normal(self.embed_out, mean=0, std=embed_dim ** -0.5)
+            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)

     def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
         # embed positions
@@ -363,7 +363,7 @@ class TransformerDecoderLayer(nn.Module):

 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-    nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
+    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
     return m
@@ -374,16 +374,16 @@ def LayerNorm(embedding_dim):

 def Linear(in_features, out_features, bias=True):
     m = nn.Linear(in_features, out_features, bias)
-    nn.init.xavier_uniform(m.weight)
-    nn.init.constant(m.bias, 0.)
+    nn.init.xavier_uniform_(m.weight)
+    nn.init.constant_(m.bias, 0.)
     return m


 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
         m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
-        nn.init.normal(m.weight, mean=0, std=embedding_dim ** -0.5)
-        nn.init.constant(m.weight[padding_idx], 0)
+        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+        nn.init.constant_(m.weight[padding_idx], 0)
     else:
         m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
     return m

fairseq/modules/adaptive_softmax.py
@@ -44,7 +44,7 @@ class AdaptiveSoftmax(nn.Module):
         def init_weights(m):
             if hasattr(m, 'weight'):
-                nn.init.xavier_uniform(m.weight)
+                nn.init.xavier_uniform_(m.weight)

         self.apply(init_weights)

fairseq/modules/downsampled_multihead_attention.py
@@ -11,7 +11,6 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable

 from fairseq.modules.scalar_bias import scalar_bias
@@ -110,14 +109,14 @@ class SingleHeadAttention(nn.Module):
         if mask_future_timesteps:
             assert query.size() == key.size(), \
                 'mask_future_timesteps only applies to self-attention'
-            attn_weights *= Variable(torch.tril(
+            attn_weights *= torch.tril(
                 attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                 diagonal=-1,
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
-            attn_weights += Variable(torch.triu(
+            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
+            attn_weights += torch.triu(
                 attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                 diagonal=0
-            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
+            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
         tgt_size = tgt_len
         if use_scalar_bias:
             attn_weights = scalar_bias(attn_weights, 2)

fairseq/modules/grad_multiply.py
@@ -13,7 +13,6 @@ class GradMultiply(torch.autograd.Function):
     def forward(ctx, x, scale):
         ctx.scale = scale
         res = x.new(x)
-        ctx.mark_shared_storage((x, res))
         return res

     @staticmethod
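
ctx.mark_shared_storage was an internal hint that the newer autograd engine no longer needs, so the call is simply dropped. A hypothetical sketch of a custom Function in the same spirit (identity forward, scaled gradient), assuming PyTorch >= 0.4; it is not fairseq's GradMultiply itself:

import torch

class ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x.clone()   # no ctx.mark_shared_storage(...) bookkeeping required

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per forward input; `scale` is not differentiable
        return grad_output * ctx.scale, None

y = ScaleGrad.apply(torch.randn(4, requires_grad=True), 0.5)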

fairseq/modules/learned_positional_embedding.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from torch.autograd import Variable
 import torch.nn as nn

 from fairseq import utils
@@ -29,7 +28,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
             positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
         else:
             positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
-        return super().forward(Variable(positions))
+        return super().forward(positions)

     def max_positions(self):
         """Maximum number of supported positions."""

fairseq/modules/linearized_convolution.py
@@ -59,8 +59,8 @@ class LinearizedConvolution(ConvTBC):
             input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
             # append next input
             input_buffer[:, -1, :] = input[:, -1, :]
-            input = utils.volatile_variable(input_buffer)
-        with utils.maybe_no_grad():
+            input = input_buffer
+        with torch.no_grad():
             output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

fairseq/modules/multihead_attention.py
@@ -38,11 +38,11 @@ class MultiheadAttention(nn.Module):
         self.reset_parameters()

     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight)
-        nn.init.xavier_uniform(self.out_proj.weight)
+        nn.init.xavier_uniform_(self.in_proj_weight)
+        nn.init.xavier_uniform_(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            nn.init.constant(self.in_proj_bias, 0.)
-            nn.init.constant(self.out_proj.bias, 0.)
+            nn.init.constant_(self.in_proj_bias, 0.)
+            nn.init.constant_(self.out_proj.bias, 0.)

     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,

fairseq/modules/sinusoidal_positional_embedding.py
@@ -8,7 +8,6 @@
 import math

 import torch
-from torch.autograd import Variable
 import torch.nn as nn

 from fairseq import utils
@@ -64,14 +63,13 @@ class SinusoidalPositionalEmbedding(nn.Module):
                 self.padding_idx,
             ).type_as(self.weights)
         self.weights = self.weights.type_as(self._float_tensor)
-        weights = Variable(self.weights)

         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            return weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
+            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

-        positions = Variable(utils.make_positions(input.data, self.padding_idx, self.left_pad))
-        return weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)
+        positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
+        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)

     def max_positions(self):
         """Maximum number of supported positions."""

fairseq/sequence_generator.py
@@ -66,14 +66,14 @@ class SequenceGenerator(object):
             maxlen_b = self.maxlen

         for sample in data_itr:
-            s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            s = utils.move_to_cuda(sample) if cuda else sample
             if 'net_input' not in s:
                 continue
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 hypos = self.generate(
                     input['src_tokens'],
                     input['src_lengths'],
@@ -91,7 +91,7 @@ class SequenceGenerator(object):
     def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
         """Generate a batch of translations."""
-        with utils.maybe_no_grad():
+        with torch.no_grad():
             return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)

     def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
@@ -492,14 +492,11 @@ class SequenceGenerator(object):
         return finalized

     def _decode(self, tokens, encoder_outs, incremental_states):
-        # wrap in Variable
-        tokens = utils.volatile_variable(tokens)
-
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 if incremental_states[model] is not None:
                     decoder_out = list(model.decoder(tokens, encoder_out, incremental_states[model]))
                 else:
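
Generation and scoring now rely on torch.no_grad() directly instead of the utils.maybe_no_grad() shim (removed in fairseq/utils.py below), which existed only to support older PyTorch builds. A minimal sketch of the inference pattern, assuming PyTorch >= 0.4 (the model and batch are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(8, 4)     # stand-in for a trained model
batch = torch.randn(2, 8)

model.eval()
with torch.no_grad():       # no graph is built, saving memory during decoding
    scores = model(batch)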

fairseq/sequence_scorer.py
@@ -23,7 +23,7 @@ class SequenceScorer(object):
     def score_batched_itr(self, data_itr, cuda=False, timer=None):
         """Iterate over a batched dataset and yield scored translations."""
         for sample in data_itr:
-            s = utils.make_variable(sample, volatile=True, cuda=cuda)
+            s = utils.move_to_cuda(sample) if cuda else sample
             if timer is not None:
                 timer.start()
             pos_scores, attn = self.score(s)
@@ -59,7 +59,7 @@ class SequenceScorer(object):
         avg_probs = None
         avg_attn = None
         for model in self.models:
-            with utils.maybe_no_grad():
+            with torch.no_grad():
                 model.eval()
                 decoder_out = model.forward(**net_input)
                 attn = decoder_out[1]

fairseq/trainer.py
@@ -10,6 +10,7 @@ Train a network across multiple GPUs.
 """

 from collections import defaultdict, OrderedDict
+import contextlib
 from itertools import chain

 import torch
@@ -112,7 +113,7 @@ class Trainer(object):
             torch.cuda.manual_seed(seed)

         # forward and backward pass
-        sample = self._prepare_sample(sample, volatile=False)
+        sample = self._prepare_sample(sample)
         loss, sample_size, logging_output, oom_fwd = self._forward(sample)
         oom_bwd = self._backward(loss)
@@ -191,7 +192,7 @@ class Trainer(object):
         oom = 0
         if sample is not None:
             try:
-                with utils.maybe_no_grad(eval):
+                with torch.no_grad() if eval else contextlib.ExitStack():
                     # calculate loss and sample size
                     loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
                     logging_output.update(logging_output_)
@@ -276,10 +277,8 @@ class Trainer(object):
     def valid_step(self, sample):
         """Do forward pass in evaluation mode."""
-        sample = self._prepare_sample(sample, volatile=True)
-
-        # forward pass
+        sample = self._prepare_sample(sample)
         _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
         assert not oom_fwd, 'Ran out of memory during validation'
@@ -344,7 +343,7 @@ class Trainer(object):
         """Get the number of parameters updates."""
         return self._num_updates

-    def _prepare_sample(self, sample, volatile):
+    def _prepare_sample(self, sample):
         if sample is None or len(sample) == 0:
             return None
-        return utils.make_variable(sample, volatile=volatile, cuda=True)
+        return utils.move_to_cuda(sample)
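
The train/eval forward path above keeps a single code branch by picking the context manager at run time: torch.no_grad() when evaluating, and contextlib.ExitStack() as a no-op stand-in otherwise. A minimal sketch of that trick, assuming PyTorch >= 0.4 (the model and loss are illustrative):

import contextlib
import torch
import torch.nn as nn

def forward_pass(model, batch, eval=False):
    # ExitStack() with no registered callbacks is an empty context manager,
    # so gradients are recorded only when eval is False.
    with torch.no_grad() if eval else contextlib.ExitStack():
        return model(batch).sum()

model = nn.Linear(8, 1)
train_loss = forward_pass(model, torch.randn(2, 8))              # builds a graph
valid_loss = forward_pass(model, torch.randn(2, 8), eval=True)   # graph-free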

fairseq/utils.py
@@ -6,14 +6,12 @@
 # can be found in the PATENTS file in the same directory.

 from collections import defaultdict, OrderedDict
-import contextlib
 import logging
 import os
 import re
 import torch
 import traceback

-from torch.autograd import Variable
 from torch.serialization import default_restore_location
@@ -169,46 +167,24 @@ def _override_model_args(args, model_arg_overrides):
     return args


-def maybe_no_grad(condition=True):
-    if hasattr(torch, 'no_grad') and condition:
-        return torch.no_grad()
-    # no-op context manager
-    return contextlib.ExitStack()
-
-
-def volatile_variable(*args, **kwargs):
-    if hasattr(torch, 'no_grad'):
-        # volatile has been deprecated, use the no_grad context manager instead
-        return Variable(*args, **kwargs)
-    else:
-        return Variable(*args, **kwargs, volatile=True)
-
-
-def make_variable(sample, volatile=False, cuda=False):
-    """Wrap input tensors in Variable class."""
+def move_to_cuda(sample):
     if len(sample) == 0:
         return {}

-    def _make_variable(maybe_tensor):
+    def _move_to_cuda(maybe_tensor):
         if torch.is_tensor(maybe_tensor):
-            if cuda and torch.cuda.is_available():
-                maybe_tensor = maybe_tensor.cuda()
-            if volatile:
-                return volatile_variable(maybe_tensor)
-            else:
-                return Variable(maybe_tensor)
+            return maybe_tensor.cuda()
         elif isinstance(maybe_tensor, dict):
             return {
-                key: _make_variable(value)
+                key: _move_to_cuda(value)
                 for key, value in maybe_tensor.items()
             }
         elif isinstance(maybe_tensor, list):
-            return [_make_variable(x) for x in maybe_tensor]
+            return [_move_to_cuda(x) for x in maybe_tensor]
         else:
             return maybe_tensor

-    return _make_variable(sample)
+    return _move_to_cuda(sample)


 INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
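
The old make_variable helper both wrapped tensors in Variable and optionally moved them to the GPU; after the merge only the device move is left, hence the rename to move_to_cuda. A usage sketch of the new helper as defined above (the nested sample structure is illustrative):

import torch
from fairseq import utils

sample = {
    'id': torch.LongTensor([0, 1]),
    'net_input': {'src_tokens': torch.LongTensor([[4, 5, 6], [7, 8, 2]])},
}
if torch.cuda.is_available():
    sample = utils.move_to_cuda(sample)   # recursively moves nested tensors to the GPU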

tests/test_utils.py
@@ -77,16 +77,6 @@ class TestUtils(unittest.TestCase):
             utils.make_positions(right_pad_input, pad, left_pad=False),
         )

-    def test_make_variable(self):
-        t = [{'k': torch.rand(5, 5)}]
-        v = utils.make_variable(t)[0]['k']
-        self.assertTrue(isinstance(v, Variable))
-        self.assertFalse(v.data.is_cuda)
-        v = utils.make_variable(t, cuda=True)[0]['k']
-        self.assertEqual(v.data.is_cuda, torch.cuda.is_available())
-
     def assertAlmostEqual(self, t1, t2):
         self.assertEqual(t1.size(), t2.size(), "size mismatch")
         self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)