Commit 41df5ff7 authored by rprenger

Fix a bug where the wrong token was used to index the log probabilities

parent e0bf5199
@@ -270,7 +270,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     lengths = torch.ones([batch_size]).long().cuda() * maxlen
-    while context_length <= (maxlen):
+    while context_length < maxlen:
         types2use = None
         if counter == 0:
             tokens2use = tokens[:, :context_length]
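
The first hunk tightens the loop bound. A minimal sketch of the off-by-one, assuming (as the surrounding code suggests, though it is not shown here) that the token buffer has maxlen positions along dim 1, so the extra iteration under `<=` would index past the end:

import torch

# Hypothetical stand-ins, not the Megatron-LM code itself.
maxlen, batch_size = 4, 2
tokens = torch.zeros(batch_size, maxlen, dtype=torch.long)

context_length = 1
while context_length < maxlen:  # with `<= maxlen`, the final pass hits position maxlen
    # Stand-in for sampling: write the next token at position context_length.
    tokens[:, context_length] = 1
    context_length += 1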
@@ -316,7 +316,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         if output_logits is None:
             output_context = F.log_softmax(output[:, :context_length, :], 2)
-            indices = torch.unsqueeze(tokens[:, :context_length],2)
+            indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
             output_logits = torch.gather(output_context, 2, indices).squeeze(2)
         else:
             indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
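
The second hunk is the indexing fix named in the commit message: the log-softmax at position i is the model's distribution over the token at position i + 1, so gathering with `tokens[:, :context_length]` scores each position against the token the model was fed rather than the token it predicted. A minimal sketch with random tensors (hypothetical shapes, not the Megatron-LM code) showing the buggy and the fixed gather side by side:

import torch
import torch.nn.functional as F

# Hypothetical shapes: logits for context_length + 1 positions, vocab of 10.
batch_size, context_length, vocab_size = 2, 5, 10
output = torch.randn(batch_size, context_length + 1, vocab_size)
tokens = torch.randint(vocab_size, (batch_size, context_length + 1))

output_context = F.log_softmax(output[:, :context_length, :], 2)

# Buggy gather: pairs position i's distribution with the token at i,
# i.e. the input token, not the one being predicted.
wrong = torch.gather(
    output_context, 2,
    torch.unsqueeze(tokens[:, :context_length], 2)).squeeze(2)

# Fixed gather: position i is scored against the token at i + 1.
right = torch.gather(
    output_context, 2,
    torch.unsqueeze(tokens[:, 1:context_length + 1], 2)).squeeze(2)

print(wrong.shape, right.shape)  # both torch.Size([2, 5])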