"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "945e5b9436736aef2c258f56b31f391e1b8ac3ef"
Commit 531cb5c3 authored by Michael Carilli

Updating word_language_model examples to evaluate with no_grad instead of volatile

parent 0d91a65e
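
For context, `volatile=True` on `Variable` stopped disabling autograd as of PyTorch 0.4; evaluation code is expected to run under `torch.no_grad()` instead. Below is a minimal, self-contained sketch of the pattern this commit switches to; the tiny model, tensors, and names are illustrative assumptions, not code from the repository.

```python
import torch
import torch.nn as nn

# Tiny stand-in model and batch; these names and shapes are assumptions
# for illustration only, not code from this repository.
model = nn.Linear(10, 10)
data = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
criterion = nn.CrossEntropyLoss()

# Old (pre-0.4) idiom removed by this commit:
#     data = Variable(data, volatile=True)   # volatile no longer has any effect
# New idiom: run the whole evaluation pass under torch.no_grad(), so no
# autograd graph is built and no gradients are tracked.
total_loss = 0.0
with torch.no_grad():
    output = model(data)
    # Accumulate the loss in fp32 (.float()) so a long-running sum cannot
    # overflow when the model/criterion compute in fp16.
    total_loss += criterion(output, targets).float().item()
print(total_loss)
```

Evaluating under `no_grad()` also skips storing intermediate activations for backward, which reduces memory use during evaluation.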
@@ -134,9 +134,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -147,13 +147,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)
@@ -149,9 +149,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -162,13 +162,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)