".github/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "332760d4b300f00a0d862e3cfe1495db3b1a14f9"
Commit 531cb5c3 authored by Michael Carilli

Updating word_language_model examples to evaluate with no_grad instead of volatile

parent 0d91a65e
@@ -134,9 +134,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -147,13 +147,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)
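The change above swaps the pre-0.4 volatile flag, which is deprecated and ignored by current PyTorch, for the torch.no_grad() context manager. A minimal standalone sketch of the new idiom, using placeholder tensors rather than this example's model:

import torch

# Inside no_grad(), autograd records no graph, so evaluation uses less
# memory and the volatile= argument to get_batch becomes unnecessary.
with torch.no_grad():
    x = torch.randn(4, 4, requires_grad=True)
    y = (x * 2).sum()
print(y.requires_grad)  # False: y carries no gradient history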
@@ -149,9 +149,9 @@ def repackage_hidden(h):
 # by the batchify function. The chunks are along dimension 0, corresponding
 # to the seq_len dimension in the LSTM.
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
@@ -162,13 +162,14 @@ def evaluate(data_source):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(eval_batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, evaluation=True)
-        output, hidden = model(data, hidden)
-        output_flat = output.view(-1, ntokens)
-        #total loss can overflow if accumulated in fp16.
-        total_loss += len(data) * criterion(output_flat, targets).data.float()
-        hidden = repackage_hidden(hidden)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i)
+            output, hidden = model(data, hidden)
+            output_flat = output.view(-1, ntokens)
+            #total loss can overflow if accumulated in fp16.
+            total_loss += len(data) * criterion(output_flat, targets).data.float()
+            hidden = repackage_hidden(hidden)
     return to_python_float(total_loss) / len(data_source)
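The in-code comment about fp16 overflow is the reason the loss is cast with .data.float() before accumulation: float16 saturates near 65504, so summing many per-batch losses in half precision can silently become inf. A small illustration with made-up loss values, not taken from this example:

import torch

total_fp16 = torch.zeros(1, dtype=torch.float16)
total_fp32 = torch.zeros(1, dtype=torch.float32)
batch_loss = torch.full((1,), 2048.0, dtype=torch.float16)  # illustrative value
for _ in range(100):
    total_fp16 += batch_loss          # saturates to inf once the sum passes ~65504
    total_fp32 += batch_loss.float()  # cast to fp32 first, as the example code does
print(total_fp16.item(), total_fp32.item())  # inf 204800.0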