Unverified Commit de344815 authored by Stella Biderman's avatar Stella Biderman Committed by GitHub
Browse files

Adds `PreTrainedModel.framework` attribute (#13817)



* Added `framework` attribute

* Update modeling_utils.py

* Update modeling_flax_utils.py

* Update modeling_tf_utils.py

* Update modeling_utils.py

* Update modeling_tf_utils.py

* Update modeling_tf_utils.py

* Update modeling_flax_utils.py

* Update modeling_tf_utils.py

* Update modeling_utils.py

* Update modeling_utils.py

* Update modeling_tf_utils.py

* Update modeling_flax_utils.py

* string -> str

* Update modeling_tf_utils.py

* string -> str

* fixup

* make flake happy
Co-authored-by: default avatarpatil-suraj <surajp815@gmail.com>
parent d70919e6
......@@ -119,6 +119,13 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
"""
return cls(config, **kwargs)
@property
def framework(self) -> str:
"""
:str: Identifies that this is a Flax model.
"""
return "flax"
@property
def config(self) -> PretrainedConfig:
return self._config
......
......@@ -652,6 +652,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"input_ids": tf.constant(DUMMY_INPUTS),
}
@property
def framework(self) -> str:
"""
:str: Identifies that this is a TensorFlow model.
"""
return "tf"
def __init__(self, config, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
if not isinstance(config, PretrainedConfig):
......
......@@ -461,6 +461,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
return {"input_ids": torch.tensor(DUMMY_INPUTS)}
@property
def framework(self) -> str:
"""
:str: Identifies that this is a PyTorch model.
"""
return "pt"
def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
super().__init__()
if not isinstance(config, PretrainedConfig):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment