Unverified commit d5a72b6e authored by Patrick von Platen, committed by GitHub

[Flax] Allow dataclasses to be jitted (#11886)

* fix_torch_device_generate_test

* remove @

* change dataclasses to flax ones

* fix typo

* fix jitted tests

* fix bert & electra
parent e6126e19
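
For context on the diff below: jax.jit can only return pytrees of JAX arrays, and a plain dataclasses.dataclass is not registered with JAX's pytree machinery, so a jitted forward pass could not return the model output objects. flax.struct.dataclass registers the class as a pytree node, which is what makes the outputs jit-compatible. A minimal sketch of the difference, with illustrative class and function names that are not part of this commit:

from dataclasses import dataclass

import flax
import jax
import jax.numpy as jnp


@dataclass
class PlainOutput:
    # not a pytree: jit cannot flatten/unflatten instances of this class
    logits: jnp.ndarray


@flax.struct.dataclass
class StructOutput:
    # registered as a pytree node, so jit can trace through it
    logits: jnp.ndarray


@jax.jit
def works(x):
    return StructOutput(logits=x * 2)  # fine: the output is a valid pytree


@jax.jit
def breaks(x):
    return PlainOutput(logits=x * 2)


print(works(jnp.ones(3)).logits)  # [2. 2. 2.]
# breaks(jnp.ones(3))  # raises TypeError: jit outputs must be valid JAX types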
--- a/src/transformers/modeling_flax_outputs.py
+++ b/src/transformers/modeling_flax_outputs.py
@@ -11,16 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Dict, Optional, Tuple

+import flax
 import jaxlib.xla_extension as jax_xla

 from .file_utils import ModelOutput


-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutput(ModelOutput):
     """
     Base class for model's outputs, with potential hidden states and attentions.
@@ -46,7 +45,7 @@ class FlaxBaseModelOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutputWithPast(ModelOutput):
     """
     Base class for model's outputs, with potential hidden states and attentions.
@@ -76,7 +75,7 @@ class FlaxBaseModelOutputWithPast(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxBaseModelOutputWithPooling(ModelOutput):
     """
     Base class for model's outputs that also contains a pooling of the last hidden states.
@@ -107,7 +106,7 @@ class FlaxBaseModelOutputWithPooling(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxMaskedLMOutput(ModelOutput):
     """
     Base class for masked language models outputs.
@@ -136,7 +135,7 @@ class FlaxMaskedLMOutput(ModelOutput):
 FlaxCausalLMOutput = FlaxMaskedLMOutput


-@dataclass
+@flax.struct.dataclass
 class FlaxNextSentencePredictorOutput(ModelOutput):
     """
     Base class for outputs of models predicting if two sentences are consecutive or not.
@@ -163,7 +162,7 @@ class FlaxNextSentencePredictorOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxSequenceClassifierOutput(ModelOutput):
     """
     Base class for outputs of sentence classification models.
@@ -189,7 +188,7 @@ class FlaxSequenceClassifierOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxMultipleChoiceModelOutput(ModelOutput):
     """
     Base class for outputs of multiple choice models.
@@ -217,7 +216,7 @@ class FlaxMultipleChoiceModelOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxTokenClassifierOutput(ModelOutput):
     """
     Base class for outputs of token classification models.
@@ -243,7 +242,7 @@ class FlaxTokenClassifierOutput(ModelOutput):
     attentions: Optional[Tuple[jax_xla.DeviceArray]] = None


-@dataclass
+@flax.struct.dataclass
 class FlaxQuestionAnsweringModelOutput(ModelOutput):
     """
     Base class for outputs of question answering models.
......
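
What the decorator swap buys in practice: each output class is now a node that jax.tree_util can flatten into array leaves and rebuild, which is exactly what jit needs at its trace boundaries. A small sketch, reusing the illustrative StructOutput from above:

import jax
import jax.numpy as jnp

# StructOutput is the illustrative class defined in the earlier sketch
out = StructOutput(logits=jnp.zeros(2))
leaves, treedef = jax.tree_util.tree_flatten(out)  # leaves == [the logits array]
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert isinstance(rebuilt, StructOutput)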
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
 from typing import Callable, Optional, Tuple

 import numpy as np

+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -55,7 +55,7 @@ _CONFIG_FOR_DOC = "BertConfig"
 _TOKENIZER_FOR_DOC = "BertTokenizer"


-@dataclass
+@flax.struct.dataclass
 class FlaxBertForPreTrainingOutput(ModelOutput):
     """
     Output type of :class:`~transformers.BertForPreTraining`.
......
--- a/src/transformers/models/electra/modeling_flax_electra.py
+++ b/src/transformers/models/electra/modeling_flax_electra.py
@@ -13,11 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
 from typing import Callable, Optional, Tuple

 import numpy as np

+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -54,7 +54,7 @@ _CONFIG_FOR_DOC = "ElectraConfig"
 _TOKENIZER_FOR_DOC = "ElectraTokenizer"


-@dataclass
+@flax.struct.dataclass
 class FlaxElectraForPreTrainingOutput(ModelOutput):
     """
     Output type of :class:`~transformers.ElectraForPreTraining`.
......
--- a/tests/test_modeling_flax_common.py
+++ b/tests/test_modeling_flax_common.py
@@ -248,31 +248,19 @@ class FlaxModelTesterMixin:
                 @jax.jit
                 def model_jitted(input_ids, attention_mask=None, **kwargs):
-                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs).to_tuple()
+                    return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

                 with self.subTest("JIT Enabled"):
-                    jitted_outputs = model_jitted(**prepared_inputs_dict)
+                    jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                 with self.subTest("JIT Disabled"):
                     with jax.disable_jit():
-                        outputs = model_jitted(**prepared_inputs_dict)
+                        outputs = model_jitted(**prepared_inputs_dict).to_tuple()

                     self.assertEqual(len(outputs), len(jitted_outputs))
                     for jitted_output, output in zip(jitted_outputs, outputs):
                         self.assertEqual(jitted_output.shape, output.shape)

-                @jax.jit
-                def model_jitted_return_dict(input_ids, attention_mask=None, **kwargs):
-                    return model(
-                        input_ids=input_ids,
-                        attention_mask=attention_mask,
-                        **kwargs,
-                    )
-
-                # jitted function cannot return OrderedDict
-                with self.assertRaises(TypeError):
-                    model_jitted_return_dict(**prepared_inputs_dict)
-
     def test_forward_signature(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
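
The test update above follows directly from the pytree change: model_jitted can return the output dataclass itself, .to_tuple() moves outside the traced function, and the removed model_jitted_return_dict check, which asserted that jit raises a TypeError when returning the output object, no longer applies. A self-contained sketch of the new pattern; TinyOutput and the toy model below are stand-ins, not transformers code:

import flax
import jax
import jax.numpy as jnp


@flax.struct.dataclass
class TinyOutput:
    # hypothetical stand-in for a transformers Flax output class
    last_hidden_state: jnp.ndarray

    def to_tuple(self):
        # mimics ModelOutput.to_tuple() for this sketch
        return (self.last_hidden_state,)


def model(input_ids, attention_mask=None):
    # toy forward pass standing in for a real Flax model
    return TinyOutput(last_hidden_state=input_ids * 2.0)


@jax.jit
def model_jitted(input_ids, attention_mask=None):
    # the traced function returns the dataclass itself ...
    return model(input_ids, attention_mask)


# ... and .to_tuple() is applied outside the trace
jitted = model_jitted(jnp.ones((1, 4))).to_tuple()

with jax.disable_jit():  # eager reference run, as in the test above
    eager = model_jitted(jnp.ones((1, 4))).to_tuple()

assert jitted[0].shape == eager[0].shape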