Unverified Commit a72f1c9f authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Add `LongT5` model (#16792)



* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleansing

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make the use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix pt-flax equivalence though not convinced with correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading #16148

* Add ckpt conversion script accoring to #16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* * Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a619828f6ddc3e65bd9bb1725a12b77fa883a46.

* Stash caching logic in Flax model

* Make side relative bias used always

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg  value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoitns

* finish

* fix flax tests
Co-authored-by: default avatarphungvanduy <pvduy23@gmail.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarpatil-suraj <surajp815@gmail.com>
parent 1690094b
......@@ -725,6 +725,27 @@ class FlaxGPTJPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"])
class FlaxLongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5Model(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxLongT5PreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxMarianModel(metaclass=DummyObject):
_backends = ["flax"]
......
......@@ -2605,6 +2605,37 @@ class LongformerSelfAttention(metaclass=DummyObject):
requires_backends(self, ["torch"])
LONGT5_PRETRAINED_MODEL_ARCHIVE_LIST = None
class LongT5EncoderModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5Model(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LongT5PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
background : coronary artery disease ( cad ) is the emerging cause of morbidity and mortality in developing world . it provides an excellent resolution for visualization of the coronaryarteries for catheter - based or operating interventions . although the association of this technique with major complications such as mortality is highly uncommon , it is frequently associated with various cardiac and noncardiac complications.materials and methods : in aortic stenosis , we aimed to report the diagnostic performance of 128-slice computed tomography coronary angiogram in 50 patients undergoing for major noncoron ary cardiac surgery referred
This diff is collapsed.
This diff is collapsed.
......@@ -213,6 +213,8 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
("blenderbot-small", "facebook/blenderbot_small-90M"),
("blenderbot", "facebook/blenderbot-400M-distill"),
("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
("longt5", "google/LongT5-Local-Base"),
("longt5", "google/LongT5-TGlobal-Base"),
}
# TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.
......
......@@ -18,6 +18,7 @@ from transformers import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
LEDConfig,
LongT5Config,
SummarizationPipeline,
T5Config,
pipeline,
......@@ -54,8 +55,8 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
)
self.assertEqual(outputs, [{"summary_text": ANY(str)}])
if not isinstance(model.config, (T5Config, LEDConfig)):
# LED, T5 can handle it.
if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
# LED, T5, LongT5 can handle it.
# Too long.
with self.assertRaises(Exception):
outputs = summarizer("This " * 1000)
......
......@@ -36,6 +36,7 @@ PATH_TO_DOC = "docs/source/en"
# Update this list with models that are supposed to be private.
PRIVATE_MODELS = [
"DPRSpanPredictor",
"LongT5Stack",
"RealmBertModel",
"T5Stack",
"TFDPRSpanPredictor",
......
......@@ -36,6 +36,7 @@ src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
src/transformers/models/longformer/modeling_longformer.py
src/transformers/models/longformer/modeling_tf_longformer.py
src/transformers/models/longt5/modeling_longt5.py
src/transformers/models/marian/modeling_marian.py
src/transformers/models/mbart/modeling_mbart.py
src/transformers/models/mobilebert/modeling_mobilebert.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment