Unverified Commit 39062d05 authored by guillaume-be's avatar guillaume-be Committed by GitHub
Browse files

Fixed target_mapping preparation for XLNet when batch size > 1 (incl. beam search) (#7267)

parent 4b3e55bd
......@@ -1313,7 +1313,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
target_mapping = torch.zeros(
(effective_batch_size, 1, sequence_length), dtype=torch.float, device=input_ids.device
)
target_mapping[0, 0, -1] = 1.0
target_mapping[:, 0, -1] = 1.0
inputs = {
"input_ids": input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment