Commit 163ac3d3 authored by Younes Belkada, committed by GitHub

Add Switch transformers (#19323)



* first commit

* add more comments

* add router v1

* clean up

- remove `tf` modeling files

* clean up

- remove `tf` modeling files

* clean up

* v0 routers

* added more router

- Implemented `ExpertsChooseMaskedRouter`

- added tests
- 2 more routers to implement

* last router

* improved docstring

- completed the docstring in `router.py`
- added more args in the config

* v0 sparse mlp

* replace wrong naming

* forward pass run

* update MOE layer

* small router update

* fixup

* consistency

* remove scatter router

* remove abstract layer

* update test and model for integration testing

* v1 conversion

* update

* hardcode hack

* all keys match

* add gin conversion, without additional libraries

* update conversion script

* delete router file

* update tests wrt router deletion

* fix router issues

* update expert code

* update, logits match, code needs refactoring

* Refactor code
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* add generate tests
Co-authored-by: younesbelkada <younesbelkada@gmail.com>

* add support for router loss
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fix forward error

* refactor a bit

* remove `FlaxSwitchTransformers` modules

* more tests pass

* Update code
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>

* fixup

* fix tests

* fix doc

* fix doc + tokenization

* fix tokenizer test

* fix test

* fix loss output

* update code for backward pass

* add loss support

* update documentation

* fix documentation, clean tokenizer

* more doc fix, cleanup example_switch

* fix failing test

* fix test

* fix test

* fix loss issue

* move layer

* update doc and fix router capacity usage

* fixup

* add sparse mlp index for documentation on hub

* fixup

* test sparse mix architecture

* Apply suggestions from code review

* Update docs/source/en/model_doc/switch_transformers.mdx

* fixup on update

* fix tests

* fix another test

* attempt fix

* Update src/transformers/models/switch_transformers/configuration_switch_transformers.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/switch_transformers/convert_switch_transformers_original_flax_checkpoint_to_pytorch.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* try

* all tests pass

* fix jitter noise

* Apply suggestions from code review

* doc tests pass

* Update src/transformers/models/switch_transformers/modeling_switch_transformers.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/models/switch_transformers/modeling_switch_transformers.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* remove assert

* change config order

* fix readme japanese

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* remove parallelizable tests + add one liners

* remove ONNX config

* fix nits

- add `T5Tokenizer` in auto mapping
- remove `Switch Transformers` from ONNX supported models

* remove `_get_router`

* remove asserts

* add check in test for `router_dtype`

* add `SwitchTransformersConfig` in `run_pipeline_test`

* Update tests/pipelines/test_pipelines_summarization.py

* add huge model conversion script

* fix slow tests

- add better casting for `Linear8bitLt`
- remove `torchscript` tests

* add make dir

* style on new script

* fix nits

- doctest
- remove `_keys_to_ignore_on_load_unexpected`

* Update src/transformers/models/switch_transformers/configuration_switch_transformers.py

* add google as authors

* fix year

* remove last `assert` statements

* standardize vertical spaces

* fix failing import

* fix another failing test

* Remove strange `authorized_keys`

* removing todo and padding that is never used
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: ybelkada <younes@huggingface.co>
Co-authored-by: Younes Belkada <younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Arthur Zucker <arthur@huggingface.co>
parent 55ba3190
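The commit message above mentions a top-1 router and router capacity. As a rough illustration of that general idea only, here is a minimal pure-Python sketch. It is not the code added in `modeling_switch_transformers.py`: the names `softmax`, `top1_route`, and `expert_capacity` are hypothetical, tokens are processed in order, and jitter noise and the router loss are left out.

```python
import math
from typing import List, Optional, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one token's router logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1_route(
    router_logits: List[List[float]], expert_capacity: int
) -> List[Tuple[Optional[int], float]]:
    """Send each token to its single most probable expert.

    Hypothetical helper: a token whose chosen expert is already full is
    marked as dropped (expert index None). Returns one
    (expert_index_or_None, router_probability) pair per token.
    """
    load = {}  # number of tokens already assigned to each expert
    assignments = []
    for logits in router_logits:
        probs = softmax(logits)
        expert = max(range(len(probs)), key=probs.__getitem__)
        if load.get(expert, 0) < expert_capacity:
            load[expert] = load.get(expert, 0) + 1
            assignments.append((expert, probs[expert]))
        else:
            assignments.append((None, probs[expert]))
    return assignments


# Three tokens, two experts, room for one token per expert.
routes = top1_route([[2.0, 0.0], [3.0, 1.0], [0.0, 5.0]], expert_capacity=1)
```

In this example the second token also prefers expert 0, which is already full, so it is marked as dropped, while the third token goes to expert 1.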
@@ -5064,6 +5064,51 @@ class Swinv2PreTrainedModel(metaclass=DummyObject):
         requires_backends(self, ["torch"])
+SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST = None
+class SwitchTransformersEncoderModel(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+class SwitchTransformersForConditionalGeneration(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+class SwitchTransformersModel(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+class SwitchTransformersPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+class SwitchTransformersSparseMLP(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+class SwitchTransformersTop1Router(metaclass=DummyObject):
+    _backends = ["torch"]
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
 T5_PRETRAINED_MODEL_ARCHIVE_LIST = None
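The hunk above registers placeholder classes that fail with a clear message when `torch` is not installed. The pattern can be sketched roughly as below. This is a simplified stand-in, not the library's `DummyObject` or `requires_backends`: the helper `requires_backends_sketch`, the class `PlaceholderModel`, and the backend name are hypothetical, and the metaclass used in the diff is skipped.

```python
import importlib.util


def requires_backends_sketch(obj, backends):
    """Raise ImportError if any named backend package cannot be found.

    Hypothetical helper that only mimics the idea behind the
    `requires_backends` call in the diff.
    """
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        raise ImportError(f"{name} requires the following backends: {', '.join(missing)}")


class PlaceholderModel:
    # Assumed backend name, chosen so that the lookup fails.
    _backends = ["a_backend_that_is_not_installed"]

    def __init__(self, *args, **kwargs):
        # Fail at construction time instead of at import time.
        requires_backends_sketch(self, self._backends)


try:
    PlaceholderModel()
except ImportError as err:
    message = str(err)
```

The point of the design is that importing the name always succeeds; the error is deferred until someone actually tries to use the class.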
@@ -20,6 +20,7 @@ from transformers import (
     LEDConfig,
     LongT5Config,
     SummarizationPipeline,
+    SwitchTransformersConfig,
     T5Config,
     pipeline,
 )
@@ -54,8 +55,8 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
         )
         self.assertEqual(outputs, [{"summary_text": ANY(str)}])
-        if not isinstance(model.config, (T5Config, LongT5Config, LEDConfig)):
-            # LED, T5, LongT5 can handle it.
+        if not isinstance(model.config, (SwitchTransformersConfig, T5Config, LongT5Config, LEDConfig)):
+            # Switch Transformers, LED, T5, LongT5 can handle it.
             # Too long.
             with self.assertRaises(Exception):
                 outputs = summarizer("This " * 1000)
@@ -39,6 +39,7 @@ PRIVATE_MODELS = [
     "LongT5Stack",
     "RealmBertModel",
     "T5Stack",
+    "SwitchTransformersStack",
     "TFDPRSpanPredictor",
 ]
@@ -5,6 +5,7 @@ docs/source/en/autoclass_tutorial.mdx
 docs/source/en/task_summary.mdx
 docs/source/en/model_doc/markuplm.mdx
 docs/source/en/model_doc/speech_to_text.mdx
+docs/source/en/model_doc/switch_transformers.mdx
 docs/source/en/model_doc/t5.mdx
 docs/source/en/model_doc/t5v1.1.mdx
 docs/source/en/model_doc/byt5.mdx