"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b21905e03d635920e41f1b59603fa9f18c62d076"
Unverified commit 696e8a43, authored by Ratthachat (Jung) and committed by GitHub

Add TFRag (#9002)

* Create modeling_tf_dpr.py

* Add TFDPR

* Add back TFPegasus, TFMarian, TFMBart, TFBlenderBot

The last commit accidentally deleted these 4 lines, so I restored them.

* Add TFDPR

* Add TFDPR

* clean up some comments, add TF input-style doc string

* Add TFDPR

* Make return_dict=False the default

* Fix return_dict bug (in .from_pretrained)

* Add get_input_embeddings()

* Create test_modeling_tf_dpr.py

The current version already passes all 27 tests!
Please see the test run at:
https://colab.research.google.com/drive/1czS_m9zy5k-iSJbzA_DP1k1xAAC_sdkf?usp=sharing



* fix quality

* delete init weights

* run fix copies

* fix repo consistency

* del config_class, load_tf_weights

They should be 'pytorch only'.

* add config_class back

After removing it, a test failed, so in the end only "use_tf_weights = None" is removed, per Lysandre's suggestion.

* newline after .. note::

* import tf, np (Necessary for ModelIntegrationTest)

* slow_test from_pretrained with from_pt=True

At the moment we don't have TF weights (since there is no official TF model yet).
Previously, I did not run the slow tests, so I missed this bug.

* Add simple TFDPRModelIntegrationTest

Note that this only tests that TF and PyTorch give approximately the same output; I could not compare against the official DPR repo's output yet.

* upload correct tf model

* remove position_ids as missing keys

* create modeling_tf_rag

* add tests for tf

* add tf tests

* revert wrong pt commit

* further refactor

* further refactor

* refactor

* Update modeling_tf_rag.py

- input_processing
- fix prepare_inputs_for_generation (mostly fixes a generate bug)
- bring back from_pretrained hack in order to test generate

* delete colab pieces of code

* Showcase greedy "generate"

Temporarily change the beam_search test to a greedy_search test to showcase that TF and PT produce equivalent output.

* cosmetic update

* correct typos

* update

* push some progress

* make easy check

* fix rag save from pretrained

* Update src/transformers/modeling_tf_utils.py

* remove commented out lines

* delete unnecessary lines

* add simple test case for nq_checkpoint

Add an nq_checkpoint test to show that the current version still fails without the hack.

* temporarily put ugly hack back again

* Add TFRagSequenceForGeneration!!

* __init__.py , import TFRagSequenceForGeneration

* Add TFRagSequence tests!

* rag init.py - add TFRagSequenceForGeneration

* fix from_pretrained

* fix prepare_inputs_for_generation

* Beam search for RagToken!

* minor clean up

* add tf.cast in TFRagModel

* More tf.cast

* Add all remaining tests (still have issues)

* delete all T5 related

* make style

* fix load weight prefix

* fix bart

* fix return_dict for tf_rag

This makes all tests pass. Hooray!

* fix some tests

* fix code quality

* fix quality check

* finish tests tf rag

* add tf rag to docs

* remove TFT5 from docstring
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* remove TFT5 from docstring
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Delete outdated comments
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* improve doc strings

* add generative model classes

* fix adjust token logic

* refactor generate for TFRag

* using shape_list, not _get_shape
Co-authored-by: Julien Plu <plu.julien@gmail.com>

* axis=[1]->axis=1

* delete NEED_HELP comment

* improve readability
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve readability
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* improve readability
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Indicate in docstrings that the model is in a developing state

As suggested by Julien

* small last changes

* apply Sylvain's suggestions

* finish tf rag
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: patrickvonplaten <patrick@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 3ced9b3e
docs/source/index.rst
@@ -296,7 +296,7 @@ TensorFlow and/or Flax.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | ProphetNet                  | ✅             | ❌             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| RAG                         | ✅             | ❌             | ✅              | ❌                 | ❌           |
+| RAG                         | ✅             | ❌             | ✅              | ✅                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | Reformer                    | ✅             | ✅             | ✅              | ❌                 | ❌           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
...
docs/source/model_doc/rag.rst
@@ -94,3 +94,24 @@ RagTokenForGeneration
 .. autoclass:: transformers.RagTokenForGeneration
     :members: forward, generate
+
+
+TFRagModel
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFRagModel
+    :members: call
+
+
+TFRagSequenceForGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFRagSequenceForGeneration
+    :members: call, generate
+
+
+TFRagTokenForGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TFRagTokenForGeneration
+    :members: call, generate
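Editor's note: a usage sketch for the three new TF classes documented above, mirroring the PyTorch example in the RAG docs. The checkpoint name, index flags, and from_pt=True are assumptions based on the commit notes (no native TF weights existed at the time this was written).

    from transformers import RagRetriever, RagTokenizer, TFRagTokenForGeneration

    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    retriever = RagRetriever.from_pretrained(
        "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
    )
    # from_pt=True converts the PyTorch checkpoint on the fly, since no TF weights exist yet
    model = TFRagTokenForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever, from_pt=True
    )

    input_dict = tokenizer.prepare_seq2seq_batch(
        "who holds the record in 100m freestyle", return_tensors="tf"
    )
    generated = model.generate(input_ids=input_dict["input_ids"])
    print(tokenizer.batch_decode(generated, skip_special_tokens=True))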
src/transformers/__init__.py
@@ -1130,6 +1130,13 @@ if is_tf_available():
         ]
     )
     _import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
+    _import_structure["models.rag"].extend(
+        [
+            "TFRagModel",
+            "TFRagSequenceForGeneration",
+            "TFRagTokenForGeneration",
+        ]
+    )
     _import_structure["models.roberta"].extend(
         [
             "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2166,6 +2173,7 @@ if TYPE_CHECKING:
         TFOpenAIGPTPreTrainedModel,
     )
     from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
+    from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
     from .models.roberta import (
         TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFRobertaForMaskedLM,
...
src/transformers/generation_tf_utils.py
@@ -441,6 +441,7 @@ class TFGenerationMixin:
         encoder_outputs,
         attention_mask,
         use_cache,
+        **kwargs
     ):
         """
         Generate sequences for each example without beam search (num_beams == 1). All returned sequence are generated
@@ -455,7 +456,7 @@ class TFGenerationMixin:
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
             )
             outputs = self(**model_inputs)
             next_token_logits = outputs[0][:, -1, :]
@@ -609,6 +610,7 @@ class TFGenerationMixin:
         use_cache,
         forced_bos_token_id,
         forced_eos_token_id,
+        **kwargs,
     ):
         """Generate sequences for each example with beam search."""
@@ -637,7 +639,7 @@ class TFGenerationMixin:
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **kwargs
             )
             outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
             next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
...
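Editor's note: a minimal standalone sketch (toy class, not library code) of what the **kwargs pass-through above enables. Model-specific tensors handed to generate(), such as RAG's doc_scores, now reach prepare_inputs_for_generation() instead of being dropped at this method boundary.

    class ToyGenerator:
        def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
            # a composite model like TFRag would pick up its extra inputs here
            return {"input_ids": input_ids, "doc_scores": kwargs.get("doc_scores")}

        def _generate_no_beam_search(self, input_ids, attention_mask=None, use_cache=None, **kwargs):
            # before this change, extra kwargs stopped at this signature and never reached the call below
            return self.prepare_inputs_for_generation(
                input_ids, past=None, attention_mask=attention_mask, use_cache=use_cache, **kwargs
            )

    print(ToyGenerator()._generate_no_beam_search([[0, 1]], doc_scores=[[0.9, 0.1]]))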
src/transformers/modeling_tf_utils.py
@@ -447,7 +447,7 @@ def input_processing(func, config, input_ids, **kwargs):
     return output

-def load_tf_weights(model, resolved_archive_file):
+def load_tf_weights(model, resolved_archive_file, _prefix=None):
     """
     Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
@@ -493,6 +493,10 @@ def load_tf_weights(model, resolved_archive_file):
             for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
                 # TF names always start with the model name so we ignore it
                 name = "/".join(weight_name.split("/")[1:])
+
+                if _prefix is not None:
+                    name = _prefix + "/" + name
+
                 saved_weights[name] = np.asarray(h5_layer_object[weight_name])

                 # Add the updated name to the final list for computing missing/unexpected values
@@ -501,6 +505,13 @@ def load_tf_weights(model, resolved_archive_file):
         # Loop over each weights from the instantiated model and compare with the weights from the H5 file
         for symbolic_weight in symbolic_weights:
             # TF names always start with the model name so we ignore it
-            symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+            if _prefix is not None:
+                delimeter = len(_prefix.split("/"))
+                symbolic_weight_name = "/".join(
+                    symbolic_weight.name.split("/")[:delimeter]
+                    + symbolic_weight.name.split("/")[delimeter + 1 :]
+                )
+            else:
+                symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

             # here we check if the current weight is among the weights from the H5 file
@@ -603,6 +614,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     # a list of re pattern of tensor names to ignore from the weights when loading the model weights
     # (and avoid unnecessary warnings).
     _keys_to_ignore_on_load_unexpected = None
+    _requires_load_weight_prefix = False

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -741,10 +753,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     def get_prefix_bias_name(self) -> Union[None, str]:
         """
-        Get the concatenated prefix name of the bias from the model name to the parent layer
+        Get the concatenated _prefix name of the bias from the model name to the parent layer

         Return:
-            :obj:`str`: The prefix name of the bias.
+            :obj:`str`: The _prefix name of the bias.
         """
         warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
         return None
@@ -1052,7 +1064,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
                   Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
                   a user or organization name, like ``dbmdz/bert-base-german-cased``.
                 - A path to a `directory` containing model weights saved using
-                  :func:`~transformersTF.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
+                  :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                 - A path or url to a `PyTorch state_dict save file` (e.g, ``./pt_model/pytorch_model.bin``). In
                   this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                   as ``config`` argument. This loading path is slower than converting the PyTorch model in a
@@ -1151,6 +1163,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         revision = kwargs.pop("revision", None)
         mirror = kwargs.pop("mirror", None)
+        load_weight_prefix = kwargs.pop("load_weight_prefix", None)

         if is_offline_mode() and not local_files_only:
             logger.info("Offline mode: forcing local_files_only=True")
@@ -1230,6 +1243,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             config.name_or_path = pretrained_model_name_or_path

+        # composed models, *e.g.* TFRag, require special treatment when it comes to loading
+        # pre-trained weights.
+        if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
+            model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
+
         # Instantiate model.
         model = cls(config, *model_args, **model_kwargs)
@@ -1239,13 +1257,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             # Load from a PyTorch checkpoint
             return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

-        model(model.dummy_inputs)  # build the network with dummy inputs
+        # we might need to extend the variable scope for composite models
+        if load_weight_prefix is not None:
+            with tf.compat.v1.variable_scope(load_weight_prefix):
+                model(model.dummy_inputs)  # build the network with dummy inputs
+        else:
+            model(model.dummy_inputs)  # build the network with dummy inputs

         assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
         # 'by_name' allow us to do transfer learning by skipping/adding layers
         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
         try:
-            missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
+            missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
         except OSError:
             raise OSError(
                 "Unable to load weights from h5 file. "
...
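Editor's note: a standalone sketch of the _prefix bookkeeping in load_tf_weights above, using plain strings. The concrete scope names are illustrative assumptions; the point is that a sub-model nested inside a composite model sits one scope level deeper than a standalone model, so the extra scope is prepended on the saved side and spliced out of the symbolic side before names can match.

    _prefix = "tf_rag_model/generator"

    # name as stored in the H5 file, after the leading model name has been stripped
    saved = "model.shared/weight"
    saved_key = _prefix + "/" + saved  # -> "tf_rag_model/generator/model.shared/weight"

    # name of the same weight in the live, nested graph
    symbolic = "tf_rag_model/generator/tf_bart_model/model.shared/weight"
    parts = symbolic.split("/")
    delimiter = len(_prefix.split("/"))  # number of scope levels to keep
    symbolic_key = "/".join(parts[:delimiter] + parts[delimiter + 1 :])

    assert saved_key == symbolic_key  # both: tf_rag_model/generator/model.shared/weight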
src/transformers/models/auto/modeling_tf_auto.py
@@ -553,7 +553,7 @@ class TFAutoModel(object):
     @classmethod
     @replace_list_option_in_docstrings(TF_MODEL_MAPPING, use_model_types=False)
-    def from_config(cls, config):
+    def from_config(cls, config, **kwargs):
         r"""
         Instantiates one of the base model classes of the library from a configuration.
@@ -575,7 +575,7 @@ class TFAutoModel(object):
             >>> model = TFAutoModel.from_config(config)
         """
         if type(config) in TF_MODEL_MAPPING.keys():
-            return TF_MODEL_MAPPING[type(config)](config)
+            return TF_MODEL_MAPPING[type(config)](config, **kwargs)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1037,7 +1037,7 @@ class TFAutoModelForSeq2SeqLM:
     @classmethod
     @replace_list_option_in_docstrings(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, use_model_types=False)
-    def from_config(cls, config):
+    def from_config(cls, config, **kwargs):
         r"""
         Instantiates one of the model classes of the library---with a sequence-to-sequence language modeling
         head---from a configuration.
@@ -1061,7 +1061,7 @@ class TFAutoModelForSeq2SeqLM:
             >>> model = TFAutoModelForSeq2SeqLM.from_config(config)
         """
         if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
-            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config)
+            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config, **kwargs)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
...
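Editor's note: a hedged example of why from_config now forwards **kwargs. TFRag builds its sub-models through the auto classes and needs to pass Keras layer arguments such as name= down to the constructor; the config values below are illustrative only.

    from transformers import BartConfig, TFAutoModelForSeq2SeqLM

    config = BartConfig(vocab_size=50265, d_model=64, encoder_layers=1, decoder_layers=1)
    # before this change, name= would have raised a TypeError inside from_config
    generator = TFAutoModelForSeq2SeqLM.from_config(config, name="generator")
    print(generator.name)  # "generator"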
src/transformers/models/bart/modeling_tf_bart.py
@@ -1015,13 +1015,16 @@ class TFBartDecoder(tf.keras.layers.Layer):
 class TFBartMainLayer(tf.keras.layers.Layer):
     config_class = BartConfig

-    def __init__(self, config: BartConfig, **kwargs):
+    def __init__(self, config: BartConfig, load_weight_prefix=None, **kwargs):
         super().__init__(**kwargs)
         self.config = config
         self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

-        with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
+        # set tf scope correctly
+        if load_weight_prefix is None:
+            load_weight_prefix = "model.shared"
+
+        with tf.compat.v1.variable_scope(load_weight_prefix) as shared_abs_scope_name:
             pass

         # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
@@ -1157,10 +1160,13 @@ class TFBartMainLayer(tf.keras.layers.Layer):
     BART_START_DOCSTRING,
 )
 class TFBartModel(TFBartPretrainedModel):
-    def __init__(self, config: BartConfig, *inputs, **kwargs):
+    _requires_load_weight_prefix = True
+
+    def __init__(self, config: BartConfig, load_weight_prefix=None, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.model = TFBartMainLayer(config, name="model")
+        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")

     def get_encoder(self):
         return self.model.encoder
@@ -1263,9 +1269,11 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageModelingLoss):
         r"model.decoder.embed_tokens.weight",
     ]

-    def __init__(self, config, *inputs, **kwargs):
+    _requires_load_weight_prefix = True
+
+    def __init__(self, config, load_weight_prefix=None, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
-        self.model = TFBartMainLayer(config, name="model")
+        self.model = TFBartMainLayer(config, load_weight_prefix=load_weight_prefix, name="model")
         self.use_cache = config.use_cache
         # final_bias_logits is registered as a buffer in pytorch, so not trainable for the the sake of consistency.
         self.final_logits_bias = self.add_weight(
...
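Editor's note: an assumption-level illustration of why the overridable scope matters. In graph mode, tf.compat.v1.variable_scope prefixes the names under which weights are created, which is exactly what load_weight_prefix overrides when TFBart lives nested inside TFRag; the scope strings below are illustrative.

    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default():  # variable_scope naming is a graph-mode construct
        with tf.compat.v1.variable_scope("model.shared"):
            standalone = tf.compat.v1.get_variable("weight", shape=(2, 2))
        with tf.compat.v1.variable_scope("tf_rag_model/generator/model.shared"):
            nested = tf.compat.v1.get_variable("weight", shape=(2, 2))

    print(standalone.name)  # model.shared/weight:0
    print(nested.name)      # tf_rag_model/generator/model.shared/weight:0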
src/transformers/models/rag/__init__.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING

-from ...file_utils import _BaseLazyModule, is_torch_available
+from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available

 _import_structure = {
@@ -30,6 +30,9 @@ _import_structure = {
 if is_torch_available():
     _import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]

+if is_tf_available():
+    _import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
+
 if TYPE_CHECKING:
     from .configuration_rag import RagConfig
@@ -39,6 +42,9 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration

+    if is_tf_available():
+        from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
+
 else:
     import importlib
     import os
...
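Editor's note: a hedged sketch of the lazy-import pattern this __init__.py relies on. The real _BaseLazyModule lives in file_utils and differs in detail; the idea is that symbols listed in _import_structure are only imported on first attribute access, so the TF classes cost nothing unless actually used.

    import importlib

    _import_structure = {"modeling_tf_rag": ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]}

    class _ToyLazyModule:
        def __getattr__(self, name):
            # resolve the requested symbol to its submodule, importing it lazily
            for module_name, symbols in _import_structure.items():
                if name in symbols:
                    module = importlib.import_module("." + module_name, "transformers.models.rag")
                    return getattr(module, name)
            raise AttributeError(name)

    # _ToyLazyModule().TFRagModel  # would trigger the import of modeling_tf_rag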
This diff is collapsed (the new src/transformers/models/rag/modeling_tf_rag.py).
src/transformers/utils/dummy_tf_objects.py
@@ -1332,6 +1332,25 @@ class TFPegasusModel:
         requires_tf(self)

+class TFRagModel:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+    @classmethod
+    def from_pretrained(self, *args, **kwargs):
+        requires_tf(self)
+
+
+class TFRagSequenceForGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+
+class TFRagTokenForGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_tf(self)
+
+
 TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
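Editor's note: a sketch of the dummy-object pattern added above. The real requires_tf lives in file_utils and its message may differ; the point is that importing the TF classes succeeds even without TensorFlow installed, and only constructing one raises a clear error.

    def requires_tf(obj):
        name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
        raise ImportError(f"{name} requires the TensorFlow library, but it was not found in your environment.")

    class TFRagModel:
        def __init__(self, *args, **kwargs):
            requires_tf(self)  # import succeeds; construction fails without TF

    try:
        TFRagModel()
    except ImportError as err:
        print(err)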
This diff is collapsed (the new tests/test_modeling_tf_rag.py).
utils/check_repo.py
@@ -121,6 +121,9 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "TFGPT2DoubleHeadsModel",
     "TFMT5EncoderModel",
     "TFOpenAIGPTDoubleHeadsModel",
+    "TFRagModel",
+    "TFRagSequenceForGeneration",
+    "TFRagTokenForGeneration",
     "TFT5EncoderModel",
     "Wav2Vec2ForCTC",
     "XLMForQuestionAnswering",
...