<!--Copyright 2021 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# GPT-J

## Overview

The GPT-J model was released in the [kingoflolz/mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) repository by Ben Wang and Aran Komatsuzaki. It is a GPT-2-like
causal language model trained on [the Pile](https://pile.eleuther.ai/) dataset.

This model was contributed by [Stella Biderman](https://huggingface.co/stellaathena).

Tips:

- To load [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) in float32 one would need at least 2x the model size in
  CPU RAM: 1x for the initial weights and another 1x to load the checkpoint. So for GPT-J it would take at least 48GB
  of CPU RAM just to load the model. To reduce the CPU RAM usage there are a few options. The `torch_dtype` argument
  can be used to initialize the model in half precision, and the `low_cpu_mem_usage` argument can be used to keep the
  RAM usage to 1x. There is also an [fp16 branch](https://huggingface.co/EleutherAI/gpt-j-6B/tree/float16) which stores
  the fp16 weights and can be used to further minimize the RAM usage. Combining all of this, it should take roughly
  12.1GB of CPU RAM to load the model.

```python
>>> from transformers import GPTJForCausalLM
>>> import torch

>>> model = GPTJForCausalLM.from_pretrained(
...     "EleutherAI/gpt-j-6B", revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True
... )
```

- The model should fit on a 16GB GPU for inference. For training or fine-tuning it would take much more GPU RAM. The
  Adam optimizer, for example, keeps four copies of the model in memory: the weights, the gradients, and the running
  average and squared average of the gradients. So it would need at least 4x the model size in GPU memory, even with
  mixed precision, since gradient updates are kept in fp32; see the rough estimate sketched below. This does not
  include the activations and data batches, which require additional GPU RAM. One should therefore explore solutions
  such as DeepSpeed to train or fine-tune the model. Another option is to use the original codebase to train/fine-tune
  the model on TPU and then convert the model to the Transformers format for inference. Instructions for that can be
  found [here](https://github.com/kingoflolz/mesh-transformer-jax/blob/master/howto_finetune.md).
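
The following is only a back-of-the-envelope sketch of that 4x estimate, not an official utility; the parameter count
is approximate, and activations, data batches, and framework overhead are not counted:

```python
>>> # Rough estimate of optimizer-state memory for full-precision Adam fine-tuning of GPT-J.
>>> num_params = 6e9  # GPT-J has roughly 6 billion parameters (approximate)
>>> bytes_per_param = 4  # fp32
>>> copies = 4  # weights + gradients + Adam first moment + Adam second moment
>>> print(f"~{num_params * bytes_per_param * copies / 1024**3:.0f} GiB")
~89 GiB
```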

- Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer. These extra
  tokens are added for the sake of efficiency on TPUs. To avoid the mismatch between the embedding matrix size and the
  vocab size, the tokenizer for [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) contains 143 extra tokens
  `<|extratoken_1|>... <|extratoken_143|>`, so the `vocab_size` of the tokenizer also becomes 50400. A quick check of
  this is shown below.
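
This minimal check (not required for normal usage) confirms that the tokenizer and the model configuration agree on
the padded vocabulary size:

```python
>>> from transformers import AutoTokenizer, GPTJConfig

>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
>>> config = GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")

>>> # Both include the padded <|extratoken_*|> entries, so the two sizes match.
>>> len(tokenizer), config.vocab_size
(50400, 50400)
```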

### Generation

The [`~generation_utils.GenerationMixin.generate`] method can be used to generate text with the GPT-J
model.

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer

>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = (
...     "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
...     "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
...     "researchers was the fact that the unicorns spoke perfect English."
... )

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(
...     input_ids,
...     do_sample=True,
...     temperature=0.9,
...     max_length=100,
... )

>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```

...or in float16 precision:

```python
>>> from transformers import GPTJForCausalLM, AutoTokenizer
>>> import torch

>>> model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16)
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

>>> prompt = (
...     "In a shocking finding, scientists discovered a herd of unicorns living in a remote, "
...     "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
...     "researchers was the fact that the unicorns spoke perfect English."
... )

>>> input_ids = tokenizer(prompt, return_tensors="pt").input_ids

>>> gen_tokens = model.generate(
...     input_ids,
...     do_sample=True,
...     temperature=0.9,
...     max_length=100,
... )

>>> gen_text = tokenizer.batch_decode(gen_tokens)[0]
```

## GPTJConfig

[[autodoc]] GPTJConfig
    - all

## GPTJModel

[[autodoc]] GPTJModel
    - forward

## GPTJForCausalLM

[[autodoc]] GPTJForCausalLM
    - forward

## GPTJForSequenceClassification

[[autodoc]] GPTJForSequenceClassification
    - forward

## GPTJForQuestionAnswering

[[autodoc]] GPTJForQuestionAnswering
    - forward

## TFGPTJModel

[[autodoc]] TFGPTJModel
    - call

## TFGPTJForCausalLM

[[autodoc]] TFGPTJForCausalLM
    - call

## TFGPTJForSequenceClassification

[[autodoc]] TFGPTJForSequenceClassification
    - call

## TFGPTJForQuestionAnswering

[[autodoc]] TFGPTJForQuestionAnswering
    - call

## FlaxGPTJModel

[[autodoc]] FlaxGPTJModel
    - __call__

## FlaxGPTJForCausalLM

[[autodoc]] FlaxGPTJForCausalLM
    - __call__