OpenDAS / Fairseq · Commit 907ca927

Better support for torch.no_grad (since volatile is deprecated)

Authored Jan 12, 2018 by Myle Ott
Parent 0b84ab19
Showing 3 changed files with 12 additions and 8 deletions (+12 -8):

    fairseq/modules/linearized_convolution.py   +2 -1
    fairseq/sequence_generator.py               +5 -3
    fairseq/utils.py                            +5 -4
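Background for the change: before PyTorch 0.4, autograd was disabled at inference time by constructing Variables with volatile=True; torch.no_grad() replaces that flag, and volatile is now deprecated. A minimal sketch of the compatibility pattern this commit standardizes on, assuming only an illustrative tensor named data (not part of the diff):

    import torch

    data = torch.rand(8, 16)  # illustrative input, not from the commit

    if hasattr(torch, 'no_grad'):
        # PyTorch >= 0.4: suppress graph construction with a context manager
        with torch.no_grad():
            result = data * 2
    else:
        # older PyTorch: mark the input volatile so no graph is built
        from torch.autograd import Variable
        result = Variable(data, volatile=True) * 2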
fairseq/modules/linearized_convolution.py

@@ -69,6 +69,7 @@ class LinearizedConvolution(ConvTBC):
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
             input = utils.volatile_variable(self.input_buffer)
-        output = F.linear(input.view(bsz, -1), weight, self.bias)
+        with utils.maybe_no_grad():
+            output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
fairseq/sequence_generator.py

@@ -71,6 +71,7 @@ class SequenceGenerator(object):
         srclen = input['src_tokens'].size(1)
         if timer is not None:
             timer.start()
-        hypos = self.generate(input['src_tokens'], beam_size=beam_size,
-                              maxlen=int(maxlen_a*srclen + maxlen_b))
+        with utils.maybe_no_grad():
+            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
+                                  maxlen=int(maxlen_a*srclen + maxlen_b))
         if timer is not None:

@@ -327,6 +328,7 @@ class SequenceGenerator(object):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, encoder_out)
+            with utils.maybe_no_grad():
+                decoder_out, attn = model.decoder(tokens, encoder_out)
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
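Design note: the guard appears at two levels, around the top-level generate() call and around each per-model decoder call inside beam search, so inference avoids building an autograd graph regardless of which entry point callers use. A minimal usage sketch, assuming only an illustrative tensor (not from the diff):

    import torch
    from fairseq import utils

    x = torch.rand(2, 5)  # illustrative input, not from the commit

    with utils.maybe_no_grad():
        # on PyTorch >= 0.4 this block builds no autograd graph;
        # on older versions maybe_no_grad() degrades to a no-op
        y = x.sum()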
fairseq/utils.py

@@ -176,7 +176,7 @@ def _upgrade_args(args):
     return args


-def maybe_no_grad(condition):
+def maybe_no_grad(condition=True):
     if hasattr(torch, 'no_grad') and condition:
         return torch.no_grad()
     # no-op context manager

@@ -185,8 +185,9 @@ def maybe_no_grad(condition):
 def volatile_variable(*args, **kwargs):
     if hasattr(torch, 'no_grad'):
-        with torch.no_grad():
-            return Variable(*args, **kwargs)
+        # volatile has been deprecated, use the no_grad context manager instead
+        return Variable(*args, **kwargs)
     else:
         return Variable(*args, **kwargs, volatile=True)
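The no-op fallback branch of maybe_no_grad is elided in the view above; only the "# no-op context manager" comment is visible. A self-contained sketch of how such a shim can be written, assuming a contextlib-based fallback (the fallback body is an assumption, not the verbatim fairseq code):

    import contextlib

    import torch


    def maybe_no_grad(condition=True):
        # use the real no_grad context manager when PyTorch provides one
        if hasattr(torch, 'no_grad') and condition:
            return torch.no_grad()

        # no-op context manager for older PyTorch (assumed implementation)
        @contextlib.contextmanager
        def _noop():
            yield
        return _noop()

Note the shift in responsibility in volatile_variable: it no longer enters no_grad itself, so call sites such as LinearizedConvolution.incremental_forward and SequenceGenerator now wrap their computation in utils.maybe_no_grad() instead.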