Commit f3e0218f authored by LysandreJik's avatar LysandreJik
Browse files

Correct device assignment in run_generation

parent 0820bb05
...@@ -125,7 +125,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k= ...@@ -125,7 +125,7 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping} inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
if xlm_lang is not None: if xlm_lang is not None:
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1) inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / temperature next_token_logits = outputs[0][0, -1, :] / temperature
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment