Unverified Commit 7e876dca authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Flax BERT] Update deprecated 'split' method (#28012)

* [Flax BERT] Update deprecated 'split' method

* fix copies
parent e737446e
......@@ -1569,7 +1569,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
......
......@@ -1344,7 +1344,7 @@ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
......
......@@ -1365,7 +1365,7 @@ class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
......
......@@ -1359,7 +1359,7 @@ class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment