Unverified commit c02cd95c, authored by Stella Biderman, committed by GitHub

GPT-J-6B (#13022)



* Test GPTJ implementation

* Fixed conflicts

* Update __init__.py

* Update __init__.py

* change GPT_J to GPTJ

* fix missing imports and typos

* use einops for now
(need to change to torch ops later)

* Use torch ops instead of einsum

* remove einops deps

* Update configuration_auto.py

* Added GPT J

* Update gptj.rst

* Update __init__.py

* Update test_modeling_gptj.py

* Added GPT J

* Changed configs to match GPT2 instead of GPT Neo

* Removed non-existent sequence model

* Update configuration_auto.py

* Update configuration_auto.py

* Update configuration_auto.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Progress on updating configs to agree with GPT2

* Update modeling_gptj.py

* num_layers -> n_layer

* layer_norm_eps -> layer_norm_epsilon

* attention_layers -> num_hidden_layers

* Update modeling_gptj.py

* attention_pdrop -> attn_pdrop

* hidden_act -> activation_function

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update configuration_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* Update modeling_gptj.py

* fix layernorm and lm_head size
delete attn_type

* Update docs/source/model_doc/gptj.rst
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* removed claim that GPT J uses local attention

* Removed GPTJForSequenceClassification

* Update src/transformers/models/gptj/configuration_gptj.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Removed unsupported boilerplate

* Update tests/test_modeling_gptj.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Update tests/test_modeling_gptj.py
Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update tests/test_modeling_gptj.py
Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update tests/test_modeling_gptj.py
Co-authored-by: Eric Hallahan <eric@hallahans.name>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update __init__.py

* Update configuration_gptj.py

* Update modeling_gptj.py

* Corrected indentation

* Remove stray backslash

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Delete .DS_Store

* Update docs to match

* Remove tf loading

* Remove config.jax

* Remove stray `else:` statement

* Remove references to `load_tf_weights_in_gptj`

* Adapt tests to match output from GPT-J 6B

* Apply suggestions from code review
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Default `activation_function` to `gelu_new`

- Specify the approximate formulation of GELU to ensure parity with the default setting of `jax.nn.gelu()`
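
For reference, a minimal sketch (editorial, not part of the PR) of the tanh-based approximation that `gelu_new` names; `jax.nn.gelu()` defaults to this same approximate form:

    import math

    import torch

    def gelu_new(x: torch.Tensor) -> torch.Tensor:
        # tanh approximation of GELU, matching jax.nn.gelu(approximate=True)
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))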

* Fix part of the config documentation

* Revert "Update configuration_auto.py"

This reverts commit e9860e9c043b6ebf57a0e705044e9ec9ba2263bb.

* Revert "Update configuration_auto.py"

This reverts commit cfaaae4c4dc70f1fbe9abd60fc8bd0b863b8c011.

* Revert "Update configuration_auto.py"

This reverts commit 687788954fd0cfbc567fa1202d56a4ff9271944f.

* Revert "Update configuration_auto.py"

This reverts commit 194d024ea87d4fcef0dcb08e57f52c47511a9fc6.

* Hyphenate GPT-J

* Undid sorting of the models alphabetically

* Reverting previous commit

* fix style and quality issues

* Update docs/source/model_doc/gptj.rst
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update tests/test_modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/configuration_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Replaced GPTJ-specific code with generic code

* Update src/transformers/models/gptj/modeling_gptj.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Made the code always use rotary positional encodings
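
A rough sketch (editorial, not the PR's exact code) of the interleaved rotary scheme this refers to: only the first `rotary_dim` channels of each attention head are rotated; the rest pass through unchanged.

    import torch

    def fixed_pos_embedding(seq_len, dim):
        # sin/cos tables for the rotary slice, each of shape (seq_len, dim // 2)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
        return torch.sin(angles), torch.cos(angles)

    def rotate_every_two(x):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        x_even, x_odd = x[..., ::2], x[..., 1::2]
        return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)

    def apply_rotary(x, sin, cos):
        # duplicate each sin/cos entry so that adjacent channel pairs share an angle
        sin = torch.repeat_interleave(sin, 2, dim=-1)
        cos = torch.repeat_interleave(cos, 2, dim=-1)
        return x * cos + rotate_every_two(x) * sin

    # Rotate the first 64 of 256 head channels, as rotary_dim=64 would.
    q = torch.randn(1, 16, 128, 256)  # (batch, heads, seq, head_dim)
    sin, cos = fixed_pos_embedding(128, 64)
    q = torch.cat([apply_rotary(q[..., :64], sin, cos), q[..., 64:]], dim=-1)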

* Update index.rst

* Fix documentation

* Combine attention classes

- Condense all attention operations into `GPTJAttention`
- Replicate GPT-2 and improve code clarity by renaming `GPTJAttention.attn_pdrop` and `GPTJAttention.resid_pdrop` to `GPTJAttention.attn_dropout` and `GPTJAttention.resid_dropout`

* Removed `config.rotary_dim` from tests

* Update test_modeling_gptj.py

* Update test_modeling_gptj.py

* Fix formatting

* Removed deprecated argument `layer_id` from `GPTJAttention`

* Update modeling_gptj.py

* Update modeling_gptj.py

* Fix code quality

* Restore model functionality

* Save `lm_head.weight` in checkpoints

* Fix crashes when loading with reduced precision

* Refactor `self._attn(...)` and rename layer weights

* make sure logits are in fp32 for sampling

* improve docs

* Add `GPTJForCausalLM` to `TextGenerationPipeline` whitelist

* Added GPT-J to the README

* Fix doc/readme consistency

* Add rough parallelization support

- Remove unused imports and variables
- Clean up docstrings
- Port experimental parallelization code from GPT-2 into GPT-J

* Clean up loose ends

* Fix index.rst
Co-authored-by: kurumuz <kurumuz1@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Eric Hallahan <eric@hallahans.name>
Co-authored-by: Leo Gao <54557097+leogao2@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: your_github_username <your_github_email>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e53af030
@@ -240,6 +240,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[Funnel Transformer](https://huggingface.co/transformers/model_doc/funnel.html)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
1. **[GPT](https://huggingface.co/transformers/model_doc/gpt.html)** (from OpenAI) released with the paper [Improving Language Understanding by Generative Pre-Training](https://blog.openai.com/language-unsupervised/) by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever.
1. **[GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)** (from OpenAI) released with the paper [Language Models are Unsupervised Multitask Learners](https://blog.openai.com/better-language-models/) by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
1. **[GPT-J](https://huggingface.co/transformers/model_doc/gptj.html)** (from EleutherAI) released in the repository [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax/) by Ben Wang and Aran Komatsuzaki.
1. **[GPT Neo](https://huggingface.co/transformers/model_doc/gpt_neo.html)** (from EleutherAI) released in the repository [EleutherAI/gpt-neo](https://github.com/EleutherAI/gpt-neo) by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
1. **[Hubert](https://huggingface.co/transformers/model_doc/hubert.html)** (from Facebook) released with the paper [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
1. **[I-BERT](https://huggingface.co/transformers/model_doc/ibert.html)** (from Berkeley) released with the paper [I-BERT: Integer-only BERT Quantization](https://arxiv.org/abs/2101.01321) by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
......
@@ -191,116 +191,118 @@ Supported models
30. :doc:`GPT-2 <model_doc/gpt2>` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask
Learners <https://blog.openai.com/better-language-models/>`__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David
Luan, Dario Amodei** and Ilya Sutskever**.
31. :doc:`GPT Neo <model_doc/gpt_neo>` (from EleutherAI) released in the repository `EleutherAI/gpt-neo
31. :doc:`GPT-J <model_doc/gptj>` (from EleutherAI) released in the repository `kingoflolz/mesh-transformer-jax
<https://github.com/kingoflolz/mesh-transformer-jax/>`__ by Ben Wang and Aran Komatsuzaki.
32. :doc:`GPT Neo <model_doc/gpt_neo>` (from EleutherAI) released in the repository `EleutherAI/gpt-neo
<https://github.com/EleutherAI/gpt-neo>`__ by Sid Black, Stella Biderman, Leo Gao, Phil Wang and Connor Leahy.
32. :doc:`Hubert <model_doc/hubert>` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech
33. :doc:`Hubert <model_doc/hubert>` (from Facebook) released with the paper `HuBERT: Self-Supervised Speech
Representation Learning by Masked Prediction of Hidden Units <https://arxiv.org/abs/2106.07447>`__ by Wei-Ning Hsu,
Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia, Ruslan Salakhutdinov, Abdelrahman Mohamed.
33. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
34. :doc:`I-BERT <model_doc/ibert>` (from Berkeley) released with the paper `I-BERT: Integer-only BERT Quantization
<https://arxiv.org/abs/2101.01321>`__ by Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney, Kurt Keutzer
34. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
35. :doc:`LayoutLM <model_doc/layoutlm>` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training
of Text and Layout for Document Image Understanding <https://arxiv.org/abs/1912.13318>`__ by Yiheng Xu, Minghao Li,
Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou.
35. :doc:`LayoutLMv2 <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutLMv2:
36. :doc:`LayoutLMv2 <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutLMv2:
Multi-modal Pre-training for Visually-Rich Document Understanding <https://arxiv.org/abs/2012.14740>`__ by Yang Xu,
Yiheng Xu, Tengchao Lv, Lei Cui, Furu Wei, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Wanxiang Che, Min
Zhang, Lidong Zhou.
36. :doc:`LayoutXLM <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutXLM:
37. :doc:`LayoutXLM <model_doc/layoutlmv2>` (from Microsoft Research Asia) released with the paper `LayoutXLM:
Multimodal Pre-training for Multilingual Visually-rich Document Understanding <https://arxiv.org/abs/2104.08836>`__
by Yiheng Xu, Tengchao Lv, Lei Cui, Guoxin Wang, Yijuan Lu, Dinei Florencio, Cha Zhang, Furu Wei.
37. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
38. :doc:`LED <model_doc/led>` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer
<https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
38. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
39. :doc:`Longformer <model_doc/longformer>` (from AllenAI) released with the paper `Longformer: The Long-Document
Transformer <https://arxiv.org/abs/2004.05150>`__ by Iz Beltagy, Matthew E. Peters, Arman Cohan.
39. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
40. :doc:`LUKE <model_doc/luke>` (from Studio Ousia) released with the paper `LUKE: Deep Contextualized Entity
Representations with Entity-aware Self-attention <https://arxiv.org/abs/2010.01057>`__ by Ikuya Yamada, Akari Asai,
Hiroyuki Shindo, Hideaki Takeda, Yuji Matsumoto.
40. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
41. :doc:`LXMERT <model_doc/lxmert>` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality
Encoder Representations from Transformers for Open-Domain Question Answering <https://arxiv.org/abs/1908.07490>`__
by Hao Tan and Mohit Bansal.
41. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
42. :doc:`M2M100 <model_doc/m2m_100>` (from Facebook) released with the paper `Beyond English-Centric Multilingual
Machine Translation <https://arxiv.org/abs/2010.11125>`__ by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi
Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman
Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin.
42. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
43. :doc:`MarianMT <model_doc/marian>` Machine translation models trained using `OPUS <http://opus.nlpl.eu/>`__ data by
Jörg Tiedemann. The `Marian Framework <https://marian-nmt.github.io/>`__ is being developed by the Microsoft
Translator Team.
43. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
44. :doc:`MBart <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Denoising Pre-training for
Neural Machine Translation <https://arxiv.org/abs/2001.08210>`__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li,
Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer.
44. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
45. :doc:`MBart-50 <model_doc/mbart>` (from Facebook) released with the paper `Multilingual Translation with Extensible
Multilingual Pretraining and Finetuning <https://arxiv.org/abs/2008.00401>`__ by Yuqing Tang, Chau Tran, Xian Li,
Peng-Jen Chen, Naman Goyal, Vishrav Chaudhary, Jiatao Gu, Angela Fan.
45. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
46. :doc:`Megatron-BERT <model_doc/megatron_bert>` (from NVIDIA) released with the paper `Megatron-LM: Training
Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
46. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
47. :doc:`Megatron-GPT2 <model_doc/megatron_gpt2>` (from NVIDIA) released with the paper `Megatron-LM: Training
Multi-Billion Parameter Language Models Using Model Parallelism <https://arxiv.org/abs/1909.08053>`__ by Mohammad
Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper and Bryan Catanzaro.
47. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
48. :doc:`MPNet <model_doc/mpnet>` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted
Pre-training for Language Understanding <https://arxiv.org/abs/2004.09297>`__ by Kaitao Song, Xu Tan, Tao Qin,
Jianfeng Lu, Tie-Yan Liu.
48. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
49. :doc:`MT5 <model_doc/mt5>` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained
text-to-text transformer <https://arxiv.org/abs/2010.11934>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir
Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
49. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
50. :doc:`Pegasus <model_doc/pegasus>` (from Google) released with the paper `PEGASUS: Pre-training with Extracted
Gap-sentences for Abstractive Summarization <https://arxiv.org/abs/1912.08777>`__ by Jingqing Zhang, Yao Zhao,
Mohammad Saleh and Peter J. Liu.
50. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
51. :doc:`ProphetNet <model_doc/prophetnet>` (from Microsoft Research) released with the paper `ProphetNet: Predicting
Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan, Weizhen Qi,
Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
51. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
52. :doc:`Reformer <model_doc/reformer>` (from Google Research) released with the paper `Reformer: The Efficient
Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
52. :doc:`RemBERT <model_doc/rembert>` (from Google Research) released with the paper `Rethinking embedding coupling in
53. :doc:`RemBERT <model_doc/rembert>` (from Google Research) released with the paper `Rethinking embedding coupling in
pre-trained language models <https://arxiv.org/pdf/2010.12821.pdf>`__ by Hyung Won Chung, Thibault Févry, Henry
Tsai, M. Johnson, Sebastian Ruder.
53. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT
54. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper `RoBERTa: A Robustly Optimized BERT
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
54. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper `RoFormer:
55. :doc:`RoFormer <model_doc/roformer>` (from ZhuiyiTechnology), released together with the paper `RoFormer:
Enhanced Transformer with Rotary Position Embedding <https://arxiv.org/pdf/2104.09864v1.pdf>`__ by Jianlin Su and
Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
55. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
56. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
56. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
57. :doc:`Splinter <model_doc/splinter>` (from Tel Aviv University), released together with the paper `Few-Shot
Question Answering by Pretraining Span Selection <https://arxiv.org/abs/2101.00438>`__ by Ori Ram, Yuval Kirstain,
Jonathan Berant, Amir Globerson, Omer Levy.
57. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
58. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
Krishna, and Kurt W. Keutzer.
58. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
59. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
59. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
60. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
60. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
61. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
61. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
62. :doc:`Vision Transformer (ViT) <model_doc/vit>` (from Google AI) released with the paper `An Image is Worth 16x16
Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`__ by Alexey Dosovitskiy,
Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias
Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
62. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
63. :doc:`VisualBERT <model_doc/visual_bert>` (from UCLA NLP) released with the paper `VisualBERT: A Simple and
Performant Baseline for Vision and Language <https://arxiv.org/pdf/1908.03557>`__ by Liunian Harold Li, Mark
Yatskar, Da Yin, Cho-Jui Hsieh, Kai-Wei Chang.
63. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
64. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
64. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
65. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
65. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
66. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
66. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
67. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
67. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
68. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
68. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
69. :doc:`XLSR-Wav2Vec2 <model_doc/xlsr_wav2vec2>` (from Facebook AI) released with the paper `Unsupervised
Cross-Lingual Representation Learning For Speech Recognition <https://arxiv.org/abs/2006.13979>`__ by Alexis
Conneau, Alexei Baevski, Ronan Collobert, Abdelrahman Mohamed, Michael Auli.
@@ -372,6 +374,8 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| GPT Neo | ❌ | ❌ | ✅ | ❌ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| GPT-J | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Hubert | ❌ | ❌ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
@@ -574,6 +578,7 @@ Flax), PyTorch, and/or TensorFlow.
    model_doc/mt5
    model_doc/gpt
    model_doc/gpt2
    model_doc/gptj
    model_doc/gpt_neo
    model_doc/hubert
    model_doc/pegasus
......
..
    Copyright 2021 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

GPT-J
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The GPT-J model was released in the `kingoflolz/mesh-transformer-jax
<https://github.com/kingoflolz/mesh-transformer-jax>`__ repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like
causal language model trained on `the Pile <https://pile.eleuther.ai/>`__ dataset.

This model was contributed by `Stella Biderman <https://huggingface.co/stellaathena>`__.
Tips:

- Running `GPT-J <https://huggingface.co/EleutherAI/gpt-j-6B>`__ in float32 precision on a GPU requires at least 24 GB
  of GPU memory, so on GPUs with less than 24 GB one should load the model in half precision:

.. code-block::

    >>> from transformers import GPTJForCausalLM
    >>> import torch

    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
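
  At roughly two bytes per parameter, the ~6 billion parameter checkpoint then occupies about 12 GB of GPU memory,
  versus roughly 24 GB in float32.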

Generation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :meth:`~transformers.generation_utils.GenerationMixin.generate` method can be used to generate text with the GPT-J
model.

.. code-block::

    >>> from transformers import AutoModelForCausalLM, AutoTokenizer

    >>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
    >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
    ...          "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
    ...          "researchers was the fact that the unicorns spoke perfect English."

    >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
    >>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
...or in float16 precision:

.. code-block::

    >>> from transformers import GPTJForCausalLM, AutoTokenizer
    >>> import torch

    >>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
    >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

    >>> prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
    ...          "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
    ...          "researchers was the fact that the unicorns spoke perfect English."

    >>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    >>> gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=100)
    >>> gen_text = tokenizer.batch_decode(gen_tokens)[0]

GPTJConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJConfig
    :members:


GPTJModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJModel
    :members: forward


GPTJForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForCausalLM
    :members: forward


GPTJForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GPTJForSequenceClassification
    :members: forward
@@ -213,6 +213,7 @@ _import_structure = {
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.gpt_neo": ["GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTNeoConfig"],
"models.gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
"models.herbert": ["HerbertTokenizer"],
"models.hubert": ["HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "HubertConfig"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
@@ -824,6 +825,15 @@ if is_torch_available():
"load_tf_weights_in_gpt_neo",
]
)
_import_structure["models.gptj"].extend(
[
"GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
"GPTJForCausalLM",
"GPTJForSequenceClassification",
"GPTJModel",
"GPTJPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1966,6 +1976,7 @@ if TYPE_CHECKING:
    from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
    from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
    from .models.gpt_neo import GPT_NEO_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTNeoConfig
    from .models.gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig
    from .models.herbert import HerbertTokenizer
    from .models.hubert import HUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, HubertConfig
    from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
@@ -2486,6 +2497,13 @@ if TYPE_CHECKING:
        GPTNeoPreTrainedModel,
        load_tf_weights_in_gpt_neo,
    )
    from .models.gptj import (
        GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPTJForCausalLM,
        GPTJForSequenceClassification,
        GPTJModel,
        GPTJPreTrainedModel,
    )
    from .models.hubert import (
        HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        HubertForCTC,
......
@@ -50,6 +50,7 @@ from . import (
    funnel,
    gpt2,
    gpt_neo,
    gptj,
    herbert,
    hubert,
    ibert,
......
@@ -26,6 +26,7 @@ from ...file_utils import CONFIG_NAME
CONFIG_MAPPING_NAMES = OrderedDict(
    [
        # Add configs here
        ("gptj", "GPTJConfig"),
        ("layoutlmv2", "LayoutLMv2Config"),
        ("beit", "BeitConfig"),
        ("rembert", "RemBertConfig"),
@@ -96,6 +97,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
    [
        # Add archive maps here
        ("gptj", "GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("layoutlmv2", "LAYOUTLMV2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("beit", "BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
        ("rembert", "REMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
@@ -158,6 +160,7 @@ CONFIG_ARCHIVE_MAP_MAPPING_NAMES = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
    [
        # Add full (and cased) model names here
        ("gptj", "GPT-J"),
        ("beit", "BeiT"),
        ("rembert", "RemBERT"),
        ("layoutlmv2", "LayoutLMv2"),
......
@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("gptj", "GPTJModel"),
        ("layoutlmv2", "LayoutLMv2Model"),
        ("beit", "BeitModel"),
        ("rembert", "RemBertModel"),
@@ -135,6 +136,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("gptj", "GPTJForCausalLM"),
        ("rembert", "RemBertForMaskedLM"),
        ("roformer", "RoFormerForMaskedLM"),
        ("bigbird_pegasus", "BigBirdPegasusForConditionalGeneration"),
@@ -183,6 +185,7 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("gptj", "GPTJForCausalLM"),
        ("rembert", "RemBertForCausalLM"),
        ("roformer", "RoFormerForCausalLM"),
        ("bigbird_pegasus", "BigBirdPegasusForCausalLM"),
@@ -286,6 +289,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("gptj", "GPTJForSequenceClassification"),
        ("layoutlmv2", "LayoutLMv2ForSequenceClassification"),
        ("rembert", "RemBertForSequenceClassification"),
        ("canine", "CanineForSequenceClassification"),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...file_utils import _LazyModule, is_torch_available


_import_structure = {
    "configuration_gptj": ["GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPTJConfig"],
}

if is_torch_available():
    _import_structure["modeling_gptj"] = [
        "GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST",
        "GPTJForCausalLM",
        "GPTJForSequenceClassification",
        "GPTJModel",
        "GPTJPreTrainedModel",
    ]


if TYPE_CHECKING:
    from .configuration_gptj import GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP, GPTJConfig

    if is_torch_available():
        from .modeling_gptj import (
            GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST,
            GPTJForCausalLM,
            GPTJForSequenceClassification,
            GPTJModel,
            GPTJPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
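
A short sketch (editorial, inferred from the file above) of what the `_LazyModule` indirection buys: importing the package is cheap, and each submodule is only imported when one of its names is first accessed.

    import transformers.models.gptj as gptj

    config_cls = gptj.GPTJConfig  # first access triggers the import of configuration_gptj
    model_cls = gptj.GPTJModel    # resolves modeling_gptj on demand (requires torch)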
# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" GPT-J model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
    # See all GPT-J models at https://huggingface.co/models?filter=gpt_j
}


class GPTJConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.GPTJModel`. It is used to
    instantiate a GPT-J model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the GPT-J `gpt-j-6B
    <https://huggingface.co/EleutherAI/gpt-j-6B>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 50400):
            Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.GPTJModel`.
        n_positions (:obj:`int`, `optional`, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_ctx (:obj:`int`, `optional`, defaults to 2048):
            Dimensionality of the causal mask (usually the same as :obj:`n_positions`).
        n_embd (:obj:`int`, `optional`, defaults to 4096):
            Dimensionality of the embeddings and hidden states.
        n_layer (:obj:`int`, `optional`, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        n_head (:obj:`int`, `optional`, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        rotary_dim (:obj:`int`, `optional`, defaults to 64):
            Number of dimensions in the embedding that Rotary Position Embedding is applied to.
        n_inner (:obj:`int`, `optional`, defaults to :obj:`None`):
            Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times :obj:`n_embd`.
        activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu_new"`):
            Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (:obj:`float`, `optional`, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the embeddings.
        attn_pdrop (:obj:`float`, `optional`, defaults to 0.0):
            The dropout ratio for the attention.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).

    Example::

        >>> from transformers import GPTJModel, GPTJConfig

        >>> # Initializing a GPT-J 6B configuration
        >>> configuration = GPTJConfig()

        >>> # Initializing a model from the configuration
        >>> model = GPTJModel(configuration)

        >>> # Accessing the model configuration
        >>> configuration = model.config
    """

    model_type = "gptj"

    def __init__(
        self,
        vocab_size=50400,
        n_positions=2048,
        n_ctx=2048,
        n_embd=4096,
        n_layer=28,
        n_head=16,
        rotary_dim=64,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        gradient_checkpointing=False,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        **kwargs
    ):
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.rotary_dim = rotary_dim
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.gradient_checkpointing = gradient_checkpointing
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    @property
    def max_position_embeddings(self):
        return self.n_positions

    @property
    def hidden_size(self):
        return self.n_embd

    @property
    def num_attention_heads(self):
        return self.n_head

    @property
    def num_hidden_layers(self):
        return self.n_layer
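
The four properties above alias the GPT-2-style attribute names to the generic names used across the library. A quick check against the defaults defined in `__init__` (editorial example, not part of the PR):

    from transformers import GPTJConfig

    config = GPTJConfig()
    assert config.hidden_size == config.n_embd == 4096
    assert config.num_hidden_layers == config.n_layer == 28
    assert config.num_attention_heads == config.n_head == 16
    assert config.max_position_embeddings == config.n_positions == 2048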
This diff is collapsed.
@@ -32,6 +32,22 @@ class TextGenerationPipeline(Pipeline):
    begging for his blessing. <eod> </s> <eos>
    """

    ALLOWED_MODELS = [
        "XLNetLMHeadModel",
        "TransfoXLLMHeadModel",
        "ReformerModelWithLMHead",
        "GPT2LMHeadModel",
        "GPTJForCausalLM",
        "GPTNeoForCausalLM",
        "OpenAIGPTLMHeadModel",
        "CTRLLMHeadModel",
        "TFXLNetLMHeadModel",
        "TFTransfoXLLMHeadModel",
        "TFGPT2LMHeadModel",
        "TFOpenAIGPTLMHeadModel",
        "TFCTRLLMHeadModel",
    ]

    def __init__(self, *args, return_full_text=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
......
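
With `GPTJForCausalLM` added to `ALLOWED_MODELS` above, the text-generation pipeline can dispatch to GPT-J checkpoints. A minimal usage sketch (assuming the `EleutherAI/gpt-j-6B` checkpoint and enough memory to load it):

    from transformers import pipeline

    generator = pipeline("text-generation", model="EleutherAI/gpt-j-6B")
    result = generator("The Pile is", do_sample=True, max_length=30)
    print(result[0]["generated_text"])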
@@ -1808,6 +1808,45 @@ def load_tf_weights_in_gpt_neo(*args, **kwargs):
    requires_backends(load_tf_weights_in_gpt_neo, ["torch"])


GPTJ_PRETRAINED_MODEL_ARCHIVE_LIST = None


class GPTJForCausalLM:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class GPTJForSequenceClassification:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class GPTJModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class GPTJPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
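
These dummies keep `from transformers import GPTJModel` importable in a torch-free environment and defer the failure to use time. A sketch of the expected behavior (assuming PyTorch is not installed):

    from transformers import GPTJModel  # succeeds even without torch

    try:
        GPTJModel()
    except ImportError as err:
        print(err)  # requires_backends explains that the "torch" backend is missing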
This diff is collapsed.