Commit d98a384c authored by patrickvonplaten

Fix bug in prepare_inputs_for_generation for XLM language generation with effective batch_size > 1

parent 81db12c3
@@ -674,7 +674,8 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         mask_token_id = self.config.mask_token_id
         lang_id = self.config.lang_id
 
-        mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
+        effective_batch_size = input_ids.shape[0]
+        mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
         input_ids = torch.cat([input_ids, mask_token], dim=1)
         if lang_id is not None:
             langs = torch.full_like(input_ids, lang_id)
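For context, here is a minimal standalone sketch (not the library code) of why the mask token tensor has to match the batch dimension. The mask_token_id value of 5 and the toy input_ids are made up for illustration; before the fix, a (1, 1) mask tensor could not be concatenated with a batch of more than one sequence.

```python
import torch

# Toy values for illustration only; in XLM these come from the model config.
mask_token_id = 5
input_ids = torch.tensor([[10, 11, 12],
                          [20, 21, 22]])  # batch of 2 sequences

# One mask token per sequence in the batch.
effective_batch_size = input_ids.shape[0]
mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long)

# With shape (batch_size, 1) the concatenation works for any batch size;
# the old (1, 1) shape only matched a batch of one and raised a size-mismatch
# error for larger batches.
input_ids = torch.cat([input_ids, mask_token], dim=1)
print(input_ids)
# tensor([[10, 11, 12,  5],
#         [20, 21, 22,  5]])
```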