Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Fairseq
Commits
7da4e062
Commit
7da4e062
authored
Dec 22, 2017
by
Myle Ott
Browse files
Support deprecation of volatile Variables in latest PyTorch
parent
5637d54e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
14 deletions
+25
-14
fairseq/multiprocessing_trainer.py
fairseq/multiprocessing_trainer.py
+15
-14
fairseq/utils.py
fairseq/utils.py
+8
-0
generate.py
generate.py
+2
-0
No files found.
fairseq/multiprocessing_trainer.py
View file @
7da4e062
...
@@ -227,20 +227,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
...
@@ -227,20 +227,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self
.
model
.
train
()
self
.
model
.
train
()
self
.
optimizer
.
zero_grad
()
self
.
optimizer
.
zero_grad
()
sample_size
,
logging_output
,
oom
=
0
,
{},
False
with
utils
.
maybe_no_grad
(
eval
):
if
self
.
_sample
is
not
None
:
sample_size
,
logging_output
,
oom
=
0
,
{},
False
try
:
if
self
.
_sample
is
not
None
:
# calculate loss and sample size
try
:
self
.
loss
,
sample_size
,
logging_output
=
self
.
criterion
(
self
.
model
,
self
.
_sample
)
# calculate loss and sample size
except
RuntimeError
as
e
:
self
.
loss
,
sample_size
,
logging_output
=
self
.
criterion
(
self
.
model
,
self
.
_sample
)
if
not
eval
and
'out of memory'
in
str
(
e
):
except
RuntimeError
as
e
:
print
(
'| WARNING: ran out of memory on GPU #{}, skipping batch'
.
format
(
device_id
))
if
not
eval
and
'out of memory'
in
str
(
e
):
oom
=
True
print
(
'| WARNING: ran out of memory on GPU #{}, skipping batch'
.
format
(
device_id
))
self
.
loss
=
None
oom
=
True
if
hasattr
(
torch
.
cuda
,
'empty_cache'
):
self
.
loss
=
None
torch
.
cuda
.
empty_cache
()
if
hasattr
(
torch
.
cuda
,
'empty_cache'
):
else
:
torch
.
cuda
.
empty_cache
()
raise
e
else
:
raise
e
return
sample_size
,
logging_output
,
oom
return
sample_size
,
logging_output
,
oom
...
...
fairseq/utils.py
View file @
7da4e062
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
# can be found in the PATENTS file in the same directory.
#
#
import
contextlib
import
logging
import
logging
import
os
import
os
import
torch
import
torch
...
@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
...
@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
if
strip
>
0
:
if
strip
>
0
:
return
tensor
[:
-
strip
]
return
tensor
[:
-
strip
]
return
tensor
return
tensor
def
maybe_no_grad
(
condition
):
if
hasattr
(
torch
,
'no_grad'
)
and
condition
:
return
torch
.
no_grad
()
# no-op context manager
return
contextlib
.
ExitStack
()
generate.py
View file @
7da4e062
...
@@ -35,6 +35,8 @@ def main():
...
@@ -35,6 +35,8 @@ def main():
print
(
args
)
print
(
args
)
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
use_cuda
=
torch
.
cuda
.
is_available
()
and
not
args
.
cpu
if
hasattr
(
torch
,
'set_grad_enabled'
):
torch
.
set_grad_enabled
(
False
)
# Load dataset
# Load dataset
if
args
.
replace_unk
is
None
:
if
args
.
replace_unk
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment