Commit a31e591d authored by thomwolf's avatar thomwolf
Browse files

fix XLM tests

parent 447de34d
...@@ -566,7 +566,7 @@ class XLMPredLayer(nn.Module): ...@@ -566,7 +566,7 @@ class XLMPredLayer(nn.Module):
scores = self.proj(x) scores = self.proj(x)
outputs = (scores,) + outputs outputs = (scores,) + outputs
if y is not None: if y is not None:
loss = F.cross_entropy(scores.view(-1, self.n_words), y, reduction='elementwise_mean') loss = F.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction='elementwise_mean')
outputs = (loss,) + outputs outputs = (loss,) + outputs
else: else:
scores = self.proj.log_prob(x) scores = self.proj.log_prob(x)
......
...@@ -185,7 +185,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester): ...@@ -185,7 +185,7 @@ class XLMModelTest(CommonTestCases.CommonModelTester):
model.eval() model.eval()
outputs = model(input_ids) outputs = model(input_ids)
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = outputs
outputs = model(input_ids, start_positions=sequence_labels, outputs = model(input_ids, start_positions=sequence_labels,
end_positions=sequence_labels, end_positions=sequence_labels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment