"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a3daabfe14d374081e3abff736e41569e701bf60"
Unverified Commit d5a72b6e authored by Patrick von Platen, committed by GitHub

[Flax] Allow dataclasses to be jitted (#11886)

* fix_torch_device_generate_test

* remove @

* change dataclasses to flax ones

* fix typo

* fix jitted tests

* fix bert & electra
parent e6126e19
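Why the swap from `@dataclass` to `@flax.struct.dataclass` matters: `jax.jit` can only return values that JAX can flatten as pytrees, and a plain Python dataclass is not registered with the pytree machinery, so returning a model output from a jitted function used to raise a TypeError (the old test below asserted exactly that). `flax.struct.dataclass` registers the class as a frozen pytree node, which is what lets output objects cross the jit boundary. A minimal sketch of the mechanism; `ToyOutput` and `forward` are illustrative names, not part of this PR:

import flax
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the output classes changed in this diff.
@flax.struct.dataclass
class ToyOutput:
    logits: jnp.ndarray


@jax.jit
def forward(x):
    # Returning the dataclass works because flax.struct.dataclass
    # registered ToyOutput as a pytree node.
    return ToyOutput(logits=x * 2)


out = forward(jnp.ones((2, 3)))
print(out.logits.shape)  # (2, 3)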
@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Dict, Optional, Tuple
+import flax
 import jaxlib.xla_extension as jax_xla
 from .file_utils import ModelOutput
-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutput(ModelOutput):
     """
     Base class for model's outputs, with potential hidden states and attentions.
@@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutputWithPast(ModelOutput):
     """
     Base class for model's outputs, with potential hidden states and attentions.
@@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutputWithPooling(ModelOutput):
     """
     Base class for model's outputs that also contains a pooling of the last hidden states.
@@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxMaskedLMOutput(ModelOutput):
     """
     Base class for masked language models outputs.
@@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput):
 FlaxCausalLMOutput = FlaxMaskedLMOutput
-@dataclass
+@flax.struct.dataclass
 class FlaxNextSentencePredictorOutput(ModelOutput):
     """
     Base class for outputs of models predicting if two sentences are consecutive or not.
@@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxSequenceClassifierOutput(ModelOutput):
     """
     Base class for outputs of sentence classification models.
@@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxMultipleChoiceModelOutput(ModelOutput):
     """
     Base class for outputs of multiple choice models.
@@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxTokenClassifierOutput(ModelOutput):
     """
     Base class for outputs of token classification models.
@@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None
-@dataclass
+@flax.struct.dataclass
 class FlaxQuestionAnsweringModelOutput(ModelOutput):
     """
     Base class for outputs of question answering models.
...
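Every output class in this file receives the same one-line decorator change. One consequence worth noting: `flax.struct.dataclass` instances are frozen, so fields are updated functionally via the generated `.replace()` method rather than mutated in place, and instances flatten like any other registered pytree. A hedged sketch, reusing the hypothetical `ToyOutput` from above:

# Fields are immutable; .replace() returns a modified copy.
out = ToyOutput(logits=jnp.zeros((2, 3)))
out2 = out.replace(logits=jnp.ones((2, 3)))

# The instance flattens into array leaves like any registered pytree.
leaves, treedef = jax.tree_util.tree_flatten(out2)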
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Callable, Optional, Tuple
 import numpy as np
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig"
 _TOKENIZER_FOR_DOC = "BertTokenizer"
-@dataclass
+@flax.struct.dataclass
 class FlaxBertForPreTrainingOutput(ModelOutput):
     """
     Output type of :class:`~transformers.BertForPreTraining`.
...
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Callable, Optional, Tuple
 import numpy as np
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig"
 _TOKENIZER_FOR_DOC = "ElectraTokenizer"
-@dataclass
+@flax.struct.dataclass
 class FlaxElectraForPreTrainingOutput(ModelOutput):
     """
     Output type of :class:`~transformers.ElectraForPreTraining`.
...
@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
                 @jax.jit
                 def model_jitted(input_ids, attention_mask=None, **kwargs):
-                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
+                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
                 with self.subTest("JIT Enabled"):
-                    jitted_outputs = model_jitted(**prepared_inputs_dict)
+                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
                 with self.subTest("JIT Disabled"):
                     with jax.disable_jit():
-                        outputs = model_jitted(**prepared_inputs_dict)
+                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()
                 self.assertEqual(len(outputs), len(jitted_outputs))
                 for jitted_output, output in zip(jitted_outputs, outputs):
                     self.assertEqual(jitted_output.shape, output.shape)
-                @jax.jit
-                def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
-                    return model(
-                        input_ids=input_ids,
-                        attention_mask=attention_mask,
-                        **kwargs,
-                    )
-                # jitted function cannot return OrderedDict
-                with self.assertRaises(TypeError):
-                    model_jitted_return_dict(**prepared_inputs_dict)
     def test_forward_signature(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
...
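The test change mirrors the new behavior: since the output object is now a pytree, the jitted function returns it directly and `.to_tuple()` (the `ModelOutput` conversion method in transformers, used in the diff above) is applied outside the traced function. The old `model_jitted_return_dict` check, which asserted that jit raises a TypeError on dict-like outputs, is dropped because that failure mode no longer exists. A condensed sketch of the jitted-vs-eager comparison pattern, with `toy_model` as a hypothetical stand-in for a real transformers Flax model and `ToyOutput` reused from the sketch above:

# toy_model is a placeholder for model(input_ids=..., attention_mask=...).
def toy_model(x):
    return ToyOutput(logits=x * 3.0)


@jax.jit
def model_jitted(x):
    # The output dataclass round-trips through jit because it is a pytree.
    return toy_model(x)


inputs = jnp.ones((2, 5))
jitted_logits = model_jitted(inputs).logits
with jax.disable_jit():
    eager_logits = model_jitted(inputs).logits

# Like the test, compare jitted and eager outputs shape-by-shape.
assert jitted_logits.shape == eager_logits.shape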