Unverified Commit 7e876dca authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[Flax BERT] Update deprecated 'split' method (#28012)

* [Flax BERT] Update deprecated 'split' method

* fix copies
parent e737446e
...@@ -1569,7 +1569,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module): ...@@ -1569,7 +1569,7 @@ class FlaxBertForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
......
...@@ -1344,7 +1344,7 @@ class FlaxRobertaForQuestionAnsweringModule(nn.Module): ...@@ -1344,7 +1344,7 @@ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
......
...@@ -1365,7 +1365,7 @@ class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module): ...@@ -1365,7 +1365,7 @@ class FlaxRobertaPreLayerNormForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
......
...@@ -1359,7 +1359,7 @@ class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module): ...@@ -1359,7 +1359,7 @@ class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module):
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.qa_outputs(hidden_states) logits = self.qa_outputs(hidden_states)
start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
start_logits = start_logits.squeeze(-1) start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment