Unverified Commit d7bd325b authored by Craig Chan's avatar Craig Chan Committed by GitHub
Browse files

Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput...


Add missing Maskformer dataclass decorator, add dataclass check in ModelOutput for subclasses (#25638)

* Add @dataclass to MaskFormerPixelDecoderOutput

* Add dataclass check if subclass of ModelOutout

* Use unittest assertRaises rather than pytest per contribution doc

* Update src/transformers/utils/generic.py per suggested change
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 05de038f
...@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput): ...@@ -118,6 +118,7 @@ class MaskFormerPixelLevelModuleOutput(ModelOutput):
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class MaskFormerPixelDecoderOutput(ModelOutput): class MaskFormerPixelDecoderOutput(ModelOutput):
""" """
MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state MaskFormer's pixel decoder module output, practically a Feature Pyramid Network. It returns the last hidden state
......
...@@ -20,7 +20,7 @@ import tempfile ...@@ -20,7 +20,7 @@ import tempfile
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from collections.abc import MutableMapping from collections.abc import MutableMapping
from contextlib import ExitStack, contextmanager from contextlib import ExitStack, contextmanager
from dataclasses import fields from dataclasses import fields, is_dataclass
from enum import Enum from enum import Enum
from typing import Any, ContextManager, List, Tuple from typing import Any, ContextManager, List, Tuple
...@@ -314,7 +314,26 @@ class ModelOutput(OrderedDict): ...@@ -314,7 +314,26 @@ class ModelOutput(OrderedDict):
lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)), lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
) )
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Subclasses of ModelOutput must use the @dataclass decorator
# This check is done in __init__ because the @dataclass decorator operates after __init_subclass__
# issubclass() would return True for issubclass(ModelOutput, ModelOutput) when False is needed
# Just need to check that the current class is not ModelOutput
is_modeloutput_subclass = self.__class__ != ModelOutput
if is_modeloutput_subclass and not is_dataclass(self):
raise TypeError(
f"{self.__module__}.{self.__class__.__name__} is not a dataclasss."
" This is a subclass of ModelOutput and so must use the @dataclass decorator."
)
def __post_init__(self): def __post_init__(self):
"""Check the ModelOutput dataclass.
Only occurs if @dataclass decorator has been used.
"""
class_fields = fields(self) class_fields = fields(self)
# Safety and consistency checks # Safety and consistency checks
......
...@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase): ...@@ -143,3 +143,23 @@ class ModelOutputTester(unittest.TestCase):
unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec) unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x) self.assertEqual(x, unflattened_x)
class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
a: float
b: Optional[float] = None
c: Optional[float] = None
class ModelOutputSubclassTester(unittest.TestCase):
def test_direct_model_output(self):
# Check that direct usage of ModelOutput instantiates without errors
ModelOutput({"a": 1.1})
def test_subclass_no_dataclass(self):
# Check that a subclass of ModelOutput without @dataclass is invalid
# A valid subclass is inherently tested other unit tests above.
with self.assertRaises(TypeError):
ModelOutputTestNoDataclass(a=1.1, b=2.2, c=3.3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment