Commit 41df5ff7 authored by rprenger

Fixing a bug where the wrong token was used to index the log probabilities

parent e0bf5199
...
@@ -270,7 +270,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
-        while context_length <= (maxlen):
+        while context_length < maxlen:
             types2use = None
             if counter == 0:
                 tokens2use = tokens[:, :context_length]
...
@@ -316,7 +316,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if output_logits is None:
                     output_context = F.log_softmax(output[:, :context_length, :], 2)
-                    indices = torch.unsqueeze(tokens[:, :context_length],2)
+                    indices = torch.unsqueeze(tokens[:, 1:context_length+1],2)
                     output_logits = torch.gather(output_context, 2, indices).squeeze(2)
                 else:
                     indices = torch.unsqueeze(new_tokens,1).unsqueeze(2)
...
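For context, a minimal standalone sketch of the off-by-one this commit fixes (toy tensors and made-up sizes, not the actual sample_sequence_batch code): the log-softmax of the output at position t is the model's distribution over the token at position t + 1, so the per-token log probabilities have to be gathered with the token indices shifted forward by one.

import torch
import torch.nn.functional as F

# Toy stand-ins; shapes mirror the real code: tokens is [batch, seq], output is [batch, seq, vocab].
batch_size, context_length, vocab_size = 2, 5, 10
tokens = torch.randint(vocab_size, (batch_size, context_length + 1))
output = torch.randn(batch_size, context_length, vocab_size)  # hypothetical model output
output_context = F.log_softmax(output, 2)

# Old indexing: pairs the distribution at position t with the token the model was
# conditioned on at t, i.e. a token it had already seen.
wrong_indices = torch.unsqueeze(tokens[:, :context_length], 2)
wrong_logits = torch.gather(output_context, 2, wrong_indices).squeeze(2)

# Fixed indexing: pairs the distribution at position t with the token at t + 1,
# the one the model actually had to predict.
fixed_indices = torch.unsqueeze(tokens[:, 1:context_length + 1], 2)
fixed_logits = torch.gather(output_context, 2, fixed_indices).squeeze(2)

The loop-bound change (<= to <) in the first hunk appears to be part of the same fix: stopping one step earlier keeps the shifted slice 1:context_length+1 within the token buffer.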