<!--Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# GPT-J

## Overview

The GPT-J model was released in the [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like
causal language model trained on [the Pile](https://pile.eleuther.ai/) dataset.

This model was contributed by [Stella Biderman](https://huggingface.co/stellaathena).

## Usage tips

- Loading [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) in float32 requires at least 2x the model size in
  RAM: 1x for the initial weights and another 1x to load the checkpoint. For GPT-J this means at least 48GB of
  RAM just to load the model. There are a few options to reduce RAM usage. The `torch_dtype` argument can be
  used to initialize the model in half-precision, on a CUDA device only. There is also a `float16` branch that stores the fp16 weights,
  which can be used to further reduce RAM usage:

```python
>>> from transformers import GPTJForCausalLM
>>> import torch

>>> device = "cuda"
>>> model = GPTJForCausalLM.from_pretrained(
...     "EleutherAI/gpt-j-6B",
...     revision="float16",
...     torch_dtype=torch.float16,
... ).to(device)
```

- The model should fit on a 16GB GPU for inference. Training or fine-tuning takes much more GPU RAM. The Adam
  optimizer, for example, keeps four copies of the model: the weights, the gradients, and the average and squared average of the gradients.
  It therefore needs at least 4x the model size in GPU memory, even with mixed precision, because gradient updates are in fp32. This
  does not include the activations and data batches, which require additional GPU RAM. Consider
  solutions such as DeepSpeed to train or fine-tune the model. Another option is to use the original codebase to
  train or fine-tune the model on TPU and then convert it to the Transformers format for inference. Instructions for
  that can be found [here](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md).

- Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer. The extra
  tokens are added for the sake of efficiency on TPUs. To avoid a mismatch between the embedding matrix size and the vocab
  size, the tokenizer for [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) contains 143 extra tokens
  `<|extratoken_1|>... <|extratoken_143|>`, so the `vocab_size` of the tokenizer also becomes 50400.
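
The arithmetic behind the tips above can be checked with a short sketch. It assumes a nominal parameter count of 6 billion (rounded; not the exact checkpoint size) and decimal gigabytes, and the helper name `weights_gb` is hypothetical, introduced only for illustration:

```python
# Nominal parameter count for GPT-J 6B (assumption: rounded to 6e9).
N_PARAMS = 6_000_000_000
BYTES_FP32 = 4
BYTES_FP16 = 2
GB = 10**9  # decimal gigabytes


def weights_gb(n_params: int, bytes_per_param: int) -> float:
    """Memory taken by one copy of the weights, in GB."""
    return n_params * bytes_per_param / GB


# Loading in float32: one copy for the initialized model, one for the checkpoint.
load_fp32_gb = 2 * weights_gb(N_PARAMS, BYTES_FP32)

# A single float16 copy of the weights, as held on the GPU for inference.
infer_fp16_gb = weights_gb(N_PARAMS, BYTES_FP16)

# Adam: weights, gradients, and two moment estimates, all counted in fp32.
# Activations and data batches are not included.
adam_fp32_gb = 4 * weights_gb(N_PARAMS, BYTES_FP32)

# Padding tokens that bring the GPT-2 vocabulary up to the embedding matrix size.
EMBEDDING_ROWS = 50400
GPT2_VOCAB = 50257
extra_tokens = [
    f"<|extratoken_{i}|>" for i in range(1, EMBEDDING_ROWS - GPT2_VOCAB + 1)
]

print(f"float32 load:      {load_fp32_gb:.0f} GB")
print(f"float16 inference: {infer_fp16_gb:.0f} GB")
print(f"Adam in fp32:      {adam_fp32_gb:.0f} GB")
print(f"extra tokens:      {len(extra_tokens)}")
```

These are lower bounds on weight storage only; actual usage depends on the framework, the batch size, and the sequence length.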

## Usage examples

The [`~generation.GenerationMixin.generate`] method can be used to generate text using the GPT-J model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = (
...     "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
...     "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
...     "researchers was the fact that the unicorns spoke perfect English."
... )

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(
...     input_ids,
...     do_sample=True,
...     temperature=0.9,
...     max_length=100,
... )
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```

...or in float16 precision:

```python
>>> from transformers import GPTJForCausalLM, AutoTokenizer
>>> import torch

>>> device = "cuda"
>>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16).to(device)
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = (
...     "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
...     "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
...     "researchers was the fact that the unicorns spoke perfect English."
... )

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

>>> gen_tokens = model.generate(
...     input_ids,
...     do_sample=True,
...     temperature=0.9,
...     max_length=100,
... )
>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```
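
Two details of the `generate` calls above are easy to miss: `max_length` counts the prompt tokens as well as the generated ones, and for a decoder-only model such as GPT-J the returned sequence starts with the prompt itself. A minimal sketch of both points, using plain Python lists as stand-ins for rows of token ids (the helper names and the ids are hypothetical, not part of Transformers):

```python
def new_token_budget(prompt_len: int, max_length: int) -> int:
    """Tokens left for generation when `max_length` caps prompt + continuation."""
    return max(max_length - prompt_len, 0)


def strip_prompt(sequence: list[int], prompt_len: int) -> list[int]:
    """Drop the echoed prompt, keeping only the newly generated token ids."""
    return sequence[prompt_len:]


# Made-up ids standing in for a tokenized prompt.
prompt_ids = [101, 102, 103, 104]
# Stand-in for one row returned by `generate`: the prompt followed by new tokens.
generated = prompt_ids + [7, 8, 9]

budget = new_token_budget(len(prompt_ids), max_length=100)
continuation = strip_prompt(generated, len(prompt_ids))

print(budget)
print(continuation)
```

With real tensors the same slice is `gen_tokens[:, input_ids.shape[1]:]`, and passing `max_new_tokens` to `generate` bounds the continuation directly rather than the total length.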

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with GPT-J. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

<PipelineTag pipeline="text-generation"/>

- Description of [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B).
- A blog on how to [Deploy GPT-J 6B for inference using Hugging Face Transformers and Amazon SageMaker](https://huggingface.co/blog/gptj-sagemaker).
- A blog on how to [Accelerate GPT-J inference with DeepSpeed-Inference on GPUs](https://www.philschmid.de/gptj-deepspeed-inference).
- A blog post introducing [GPT-J-6B: 6B JAX-Based Transformer](https://arankomatsuzaki.wordpress.com/2021/06/04/gpt-j/). 🌎
- A notebook for [GPT-J-6B Inference Demo](https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb). 🌎
- Another notebook demonstrating [Inference with GPT-J-6B](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/GPT-J-6B/Inference_with_GPT_J_6B.ipynb).
- [Causal language modeling](https://huggingface.co/course/en/chapter7/6?fw=pt#training-a-causal-language-model-from-scratch) chapter of the 🤗 Hugging Face Course.
- [`GPTJForCausalLM`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling#gpt-2gpt-and-causal-language-modeling), [text generation example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-generation), and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb).
- [`TFGPTJForCausalLM`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/language-modeling#run_clmpy) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).
- [`FlaxGPTJForCausalLM`] is supported by this [causal language modeling example script](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#causal-language-modeling) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb).

**Documentation resources**
- [Text classification task guide](../tasks/sequence_classification)
- [Question answering task guide](../tasks/question_answering)
- [Causal language modeling task guide](../tasks/language_modeling)

## GPTJConfig

[[autodoc]] GPTJConfig
    - all

<frameworkcontent>
<pt>

## GPTJModel

[[autodoc]] GPTJModel
    - forward

## GPTJForCausalLM

[[autodoc]] GPTJForCausalLM
    - forward

## GPTJForSequenceClassification

[[autodoc]] GPTJForSequenceClassification
    - forward

## GPTJForQuestionAnswering

[[autodoc]] GPTJForQuestionAnswering
    - forward

</pt>
<tf>

## TFGPTJModel

[[autodoc]] TFGPTJModel
    - call

## TFGPTJForCausalLM

[[autodoc]] TFGPTJForCausalLM
    - call

## TFGPTJForSequenceClassification

[[autodoc]] TFGPTJForSequenceClassification
    - call

## TFGPTJForQuestionAnswering

[[autodoc]] TFGPTJForQuestionAnswering
    - call

</tf>
<jax>

## FlaxGPTJModel

[[autodoc]] FlaxGPTJModel
    - __call__

## FlaxGPTJForCausalLM

[[autodoc]] FlaxGPTJForCausalLM
    - __call__
</jax>
</frameworkcontent>