Commit a31e591d authored by thomwolf's avatar thomwolf
Browse files

fix XLM tests

parent 447de34d
......@@ -566,7 +566,7 @@ class XLMPredLayer(nn.Module):
scores = self.proj(x)
outputs = (scores,) + outputs
if y is not None:
loss = F.cross_entropy(scores.view(-1, self.n_words), y, reduction='elementwise_mean')
loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction='elementwise_mean')
outputs = (loss,) + outputs
else:
scores = self.proj.log_prob(x)
......
......@@ -185,7 +185,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
model.eval()
outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment