OpenDAS / apex / Commits / 531cb5c3

Commit 531cb5c3 authored May 07, 2018 by Michael Carilli

Updating word_language_model examples to evaluate with no_grad instead of volatile

parent 0d91a65e
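A note for context (not part of the commit): in PyTorch 0.4 and later, volatile=True on a Variable is deprecated and has no effect, so inference is expected to run inside torch.no_grad() instead. A minimal sketch of the pattern, using a hypothetical toy model rather than the LSTM from these examples:

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the model and batch used in the examples.
    model = nn.Linear(10, 10)
    x = torch.randn(4, 10)

    # Old style (PyTorch < 0.4): volatile=True told autograd not to build a graph.
    #   data = Variable(x, volatile=True)
    #   output = model(data)

    # New style: run evaluation inside torch.no_grad() so no graph is recorded.
    model.eval()
    with torch.no_grad():
        output = model(x)

    print(output.requires_grad)  # False: no gradients will flow through output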
Showing 2 changed files with 20 additions and 18 deletions

examples/word_language_model/main.py                  +10 -9
examples/word_language_model/main_fp16_optimizer.py   +10 -9
examples/word_language_model/main.py

@@ -134,9 +134,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -147,13 +147,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)
examples/word_language_model/main_fp16_optimizer.py

@@ -149,9 +149,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -162,13 +162,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)
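As a usage sketch only (toy model and assumed names, not code from this repository), the updated structure can be exercised to confirm that evaluation under torch.no_grad() builds no autograd graph while the running loss is still accumulated in fp32, which is what the .float() call in the diffs guards against:

    import torch
    import torch.nn as nn

    # Toy stand-ins; the real examples evaluate an LSTM language model.
    model = nn.Linear(8, 8)
    criterion = nn.MSELoss()
    batches = [(torch.randn(16, 8), torch.randn(16, 8)) for _ in range(4)]

    def evaluate(batches):
        model.eval()
        total_loss = 0.0
        with torch.no_grad():  # replaces the old volatile=True Variables
            for data, targets in batches:
                output = model(data)
                # Accumulate in fp32 even if the model ran in fp16,
                # mirroring the .float() in the diffs above.
                total_loss += len(data) * criterion(output, targets).float().item()
        return total_loss / sum(len(d) for d, _ in batches)

    print(evaluate(batches))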