<!--Copyright 2020 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Trainer

The [`Trainer`] class provides an API for feature-complete training in PyTorch, and it supports distributed training on multiple GPUs/TPUs, mixed precision for [NVIDIA GPUs](https://nvidia.github.io/apex/), [AMD GPUs](https://rocm.docs.amd.com/en/latest/rocm.html), and [`torch.amp`](https://pytorch.org/docs/stable/amp.html) for PyTorch. [`Trainer`] goes hand-in-hand with the [`TrainingArguments`] class, which offers a wide range of options to customize how a model is trained. Together, these two classes provide a complete training API.

[`Seq2SeqTrainer`] and [`Seq2SeqTrainingArguments`] inherit from the [`Trainer`] and [`TrainingArguments`] classes and they're adapted for training models for sequence-to-sequence tasks such as summarization or translation.

<Tip warning={true}>

The [`Trainer`] class is optimized for 🤗 Transformers models and can have surprising behaviors
when used with other models. When using it with your own model, make sure:

- your model always returns tuples or subclasses of [`~utils.ModelOutput`]
- your model can compute the loss if a `labels` argument is provided and that loss is returned as the first
  element of the tuple (if your model returns tuples)
- your model can accept multiple label arguments (use `label_names` in [`TrainingArguments`] to indicate their names to the [`Trainer`]) but none of them should be named `"label"`

</Tip>

## Trainer[[api-reference]]

[[autodoc]] Trainer
    - all

## Seq2SeqTrainer

[[autodoc]] Seq2SeqTrainer
    - evaluate
    - predict

## TrainingArguments

[[autodoc]] TrainingArguments
    - all

## Seq2SeqTrainingArguments

[[autodoc]] Seq2SeqTrainingArguments
    - all

## Specific GPUs Selection

Let's discuss how you can tell your program which GPUs are to be used and in what order.

To use only a subset of your GPUs with [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), specify the number of GPUs to use. For example, if you have 4 GPUs but want to use only the first 2, run:

```bash
torchrun --nproc_per_node=2 trainer-program.py ...
```

If you have either [`accelerate`](https://github.com/huggingface/accelerate) or [`deepspeed`](https://github.com/microsoft/DeepSpeed) installed, you can also accomplish the same by using one of:

```bash
accelerate launch --num_processes 2 trainer-program.py ...
```

```bash
deepspeed --num_gpus 2 trainer-program.py ...
```

You don't need to use the Accelerate or [the Deepspeed integration](deepspeed) features to use these launchers.
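
The three launchers take the process count through different flags. As a minimal sketch, the hypothetical helper below (`launch_command` is an illustrative name, not part of any of these tools) builds the equivalent command line for each launcher, using only the flags shown above:

```python
import shlex

# Flags exactly as used in the commands above; the helper itself is illustrative.
LAUNCHERS = {
    "torchrun": ["torchrun", "--nproc_per_node={n}"],
    "accelerate": ["accelerate", "launch", "--num_processes", "{n}"],
    "deepspeed": ["deepspeed", "--num_gpus", "{n}"],
}


def launch_command(launcher, num_gpus, script="trainer-program.py", script_args=()):
    """Return the argv list that starts `script` on `num_gpus` GPUs."""
    if launcher not in LAUNCHERS:
        raise ValueError(f"unknown launcher: {launcher!r}")
    prefix = [part.format(n=num_gpus) for part in LAUNCHERS[launcher]]
    return prefix + [script, *script_args]


for name in LAUNCHERS:
    # Print a shell-ready version of each command.
    print(shlex.join(launch_command(name, 2)))
```

The returned list can be passed directly to `subprocess.run` if you prefer to start training from a Python script.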


Until now you were able to tell the program how many GPUs to use. Now let's discuss how to select specific GPUs and control their order.

The following environment variables help you control which GPUs to use and their order.

**`CUDA_VISIBLE_DEVICES`**

If you have multiple GPUs and you'd like to use only 1 or a few of those GPUs, set the environment variable `CUDA_VISIBLE_DEVICES` to a list of the GPUs to be used.

For example, let's say you have 4 GPUs: 0, 1, 2 and 3. To run only on the physical GPUs 0 and 2, you can do:

```bash
CUDA_VISIBLE_DEVICES=0,2 torchrun trainer-program.py ...
```

So now PyTorch will see only 2 GPUs, where your physical GPUs 0 and 2 are mapped to `cuda:0` and `cuda:1` respectively.

You can even change their order:

```bash
CUDA_VISIBLE_DEVICES=2,0 torchrun trainer-program.py ...
```

Here your physical GPUs 0 and 2 are mapped to `cuda:1` and `cuda:0` respectively.
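
To make the renumbering concrete, here is a small sketch of the mapping. `visible_device_map` is a hypothetical helper written for this page, and it assumes the variable holds plain integer indices:

```python
def visible_device_map(cuda_visible_devices):
    """Map logical device names (cuda:N) to the physical GPU indices listed."""
    if cuda_visible_devices.strip() == "":
        return {}  # an empty value exposes no GPUs at all
    physical = [int(item) for item in cuda_visible_devices.split(",")]
    # The position in the list becomes the logical index the program sees.
    return {f"cuda:{logical}": gpu for logical, gpu in enumerate(physical)}


print(visible_device_map("0,2"))  # {'cuda:0': 0, 'cuda:1': 2}
print(visible_device_map("2,0"))  # {'cuda:0': 2, 'cuda:1': 0}
```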

The above examples were all for `DistributedDataParallel` use pattern, but the same method works for [`DataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html) as well:

```bash
CUDA_VISIBLE_DEVICES=2,0 python trainer-program.py ...
```

To emulate an environment without GPUs simply set this environment variable to an empty value like so:

```bash
CUDA_VISIBLE_DEVICES= python trainer-program.py ...
```

As with any environment variable you can, of course, export those instead of adding these to the command line, as in:


```bash
export CUDA_VISIBLE_DEVICES=0,2
torchrun trainer-program.py ...
```

However, this approach can be confusing since you may forget that you set the environment variable earlier and not understand why the wrong GPUs are used. Therefore, it's common practice to set the environment variable just for a specific run, on the same command line, as shown in most examples of this section.
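
The same per-run scoping can be achieved when launching from Python. This is a minimal sketch using only the standard library: `run_with_gpus` is a hypothetical helper, and the child command is a stand-in for `torchrun trainer-program.py ...` that simply reports what it sees.

```python
import os
import subprocess
import sys


def run_with_gpus(argv, gpus):
    """Run `argv` with CUDA_VISIBLE_DEVICES set for that child process only."""
    # Copy the current environment and override one variable; os.environ is untouched.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpus)
    return subprocess.run(argv, env=env, check=True, capture_output=True, text=True)


# Stand-in for the real training command: a child that prints the variable it received.
child = [sys.executable, "-c", "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"]
result = run_with_gpus(child, "0,2")
print(result.stdout.strip())
```

Because the variable is passed through `env` rather than written to `os.environ`, later runs started from the same script are not affected.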

**`CUDA_DEVICE_ORDER`**

There is an additional environment variable `CUDA_DEVICE_ORDER` that controls how the physical devices are ordered. The two choices are:

1. ordered by PCIe bus IDs (matches `nvidia-smi` and `rocm-smi`'s order) - this is the default.

```bash
export CUDA_DEVICE_ORDER=PCI_BUS_ID
```

2. ordered by GPU compute capabilities

```bash
export CUDA_DEVICE_ORDER=FASTEST_FIRST
```

Most of the time you don't need to care about this environment variable, but it's very helpful if you have a lopsided setup where an old and a new GPU are physically inserted in such a way that the slow older card appears to be first. One way to fix that is to swap the cards. But if you can't swap the cards (e.g., if the cooling of the devices gets impacted) then setting `CUDA_DEVICE_ORDER=FASTEST_FIRST` will always put the newer faster card first. It'll be somewhat confusing though since `nvidia-smi` (or `rocm-smi`) will still report them in the PCIe order.

The other solution to swapping the order is to use:

```bash
export CUDA_VISIBLE_DEVICES=1,0
```
In this example we are working with just 2 GPUs, but of course the same would apply to as many GPUs as your computer has.

Also, if you do set this environment variable, it's best to set it in your `~/.bashrc` file or some other startup config file and forget about it.

## Trainer Integrations

The [`Trainer`] has been extended to support libraries that may dramatically improve your training
time and fit much bigger models.

Currently it supports third party solutions, [DeepSpeed](https://github.com/microsoft/DeepSpeed) and [PyTorch FSDP](https://pytorch.org/docs/stable/fsdp.html), which implement parts of the paper [ZeRO: Memory Optimizations
Toward Training Trillion Parameter Models, by Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He](https://arxiv.org/abs/1910.02054).

This provided support is new and experimental as of this writing. While the support for DeepSpeed and PyTorch FSDP is active and we welcome issues around it, we don't support the FairScale integration anymore since it has been integrated in PyTorch main (see the [PyTorch FSDP integration](#pytorch-fully-sharded-data-parallel)).

<a id='zero-install-notes'></a>

### CUDA Extension Installation Notes

As of this writing, DeepSpeed requires compilation of CUDA C++ code before it can be used.

While all installation issues should be dealt with through the corresponding GitHub Issues of [DeepSpeed](https://github.com/microsoft/DeepSpeed/issues), there are a few common issues that one may encounter while building
any PyTorch extension that needs to build CUDA extensions.

Therefore, if you encounter a CUDA-related build issue while doing the following:

```bash
pip install deepspeed
```

please read the following notes first.

In these notes we give examples for what to do when `pytorch` has been built with CUDA `10.2`. If your situation is
different remember to adjust the version number to the one you are after.

#### Possible problem #1

While PyTorch comes with its own CUDA toolkit, to build DeepSpeed you must have an identical version of CUDA
installed system-wide.

For example, if you installed `pytorch` with `cudatoolkit==10.2` in the Python environment, you also need to have
CUDA `10.2` installed system-wide.

The exact location may vary from system to system, but `/usr/local/cuda-10.2` is the most common location on many
Unix systems. When CUDA is correctly set up and added to the `PATH` environment variable, one can find the
installation location by doing:

```bash
which nvcc
```

If you don't have CUDA installed system-wide, install it first. You will find the instructions by using your favorite
search engine. For example, if you're on Ubuntu you may want to search for: [ubuntu cuda 10.2 install](https://www.google.com/search?q=ubuntu+cuda+10.2+install).
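
One way to check for this mismatch before building is to compare the two version strings. The sketch below assumes the usual `release X.Y` wording in the output of `nvcc --version`. The helper names are hypothetical, and the first argument is meant to be the value of `torch.version.cuda` from your environment:

```python
import re


def nvcc_release(nvcc_output):
    """Extract the 'major.minor' CUDA release from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def versions_match(torch_cuda, nvcc_output):
    """True when the system toolkit matches the CUDA version PyTorch was built with."""
    return torch_cuda is not None and torch_cuda == nvcc_release(nvcc_output)


# Sample text in the shape `nvcc --version` prints; on a real system capture it with
# subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
sample = "Cuda compilation tools, release 10.2, V10.2.89"
print(versions_match("10.2", sample))  # True
print(versions_match("11.0", sample))  # False
```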

#### Possible problem #2

Another possible common problem is that you may have more than one CUDA toolkit installed system-wide. For example you
may have:

```bash
/usr/local/cuda-10.2
/usr/local/cuda-11.0
```

Now, in this situation you need to make sure that your `PATH` and `LD_LIBRARY_PATH` environment variables contain
the correct paths to the desired CUDA version. Typically, package installers will set these to point to whichever
version was installed last. If the package build fails because it can't find the right
CUDA version despite it being installed system-wide, you need to adjust the 2 aforementioned
environment variables.

First, you may look at their contents:

```bash
echo $PATH
echo $LD_LIBRARY_PATH
```

so you get an idea of what is inside.

It's possible that `LD_LIBRARY_PATH` is empty.

`PATH` lists the locations where executables can be found and `LD_LIBRARY_PATH` lists where shared libraries
are to be looked for. In both cases, earlier entries have priority over the later ones. `:` is used to separate multiple
entries.

Now, to tell the build program where to find the specific CUDA toolkit, insert the desired paths to be listed first by
doing:

```bash
export PATH=/usr/local/cuda-10.2/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64:$LD_LIBRARY_PATH
```

Note that we aren't overwriting the existing values, but prepending instead.

Of course, adjust the version number and the full path if need be. Check that the directories you assign actually
exist. The `lib64` sub-directory is where the various CUDA `.so` objects, like `libcudart.so`, reside. It's unlikely
that your system will have it named differently, but if it does, adjust it to reflect your setup.
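
If the build is launched from a Python script rather than a shell, the same prepending can be done on `os.environ`. This is a minimal sketch: `prepend_entry` is a hypothetical helper, and the CUDA location is the example path used above, which you should adjust.

```python
import os
import shutil


def prepend_entry(current, entry, sep=":"):
    """Put `entry` first in a PATH-style string, dropping any later duplicate."""
    parts = [p for p in current.split(sep) if p and p != entry]
    return sep.join([entry, *parts])


cuda_home = "/usr/local/cuda-10.2"  # adjust to the toolkit you want the build to use
os.environ["PATH"] = prepend_entry(os.environ.get("PATH", ""), f"{cuda_home}/bin")
os.environ["LD_LIBRARY_PATH"] = prepend_entry(
    os.environ.get("LD_LIBRARY_PATH", ""), f"{cuda_home}/lib64"
)

# shutil.which searches PATH in order, so this reports the nvcc the build would pick up
# (None if no nvcc is found on this machine).
print(shutil.which("nvcc"))
```

These changes only affect the current Python process and the child processes it starts. They do not modify your shell's environment.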


#### Possible problem #3

Some older CUDA versions may refuse to build with newer compilers. For example, you may have `gcc-9` but it wants
`gcc-7`.

There are various ways to go about it.

If you can install the latest CUDA toolkit it typically should support the newer compiler.

Alternatively, you could install the lower version of the compiler in addition to the one you already have, or you may
already have it but it's not the default one, so the build system can't see it. If you have `gcc-7` installed but the
build system complains it can't find it, the following might do the trick:

```bash
sudo ln -s /usr/bin/gcc-7  /usr/local/cuda-10.2/bin/gcc
sudo ln -s /usr/bin/g++-7  /usr/local/cuda-10.2/bin/g++
```

Here, we are making a symlink to `gcc-7` from `/usr/local/cuda-10.2/bin/gcc` and since
`/usr/local/cuda-10.2/bin/` should be in the `PATH` environment variable (see the previous problem's solution), it
should find `gcc-7` (and `g++-7`) and then the build will succeed.

As always make sure to edit the paths in the example to match your situation.


### PyTorch Fully Sharded Data Parallel

To accelerate training huge models on larger batch sizes, we can use a fully sharded data parallel model.
This type of data parallel paradigm enables fitting more data and larger models by sharding the optimizer states, gradients and parameters.
To read more about it and the benefits, check out the [Fully Sharded Data Parallel blog](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/).
We have integrated the latest PyTorch's Fully Sharded Data Parallel (FSDP) training feature.
All you need to do is enable it through the config.

**Required PyTorch version for FSDP support**: PyTorch >=2.1.0

**Usage**:

- Make sure you have added the distributed launcher
`-m torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE` if you haven't been using it already.

- **Sharding Strategy**:
  - FULL_SHARD : Shards optimizer states + gradients + model parameters across data parallel workers/GPUs.
    For this, add `--fsdp full_shard` to the command line arguments.
  - SHARD_GRAD_OP : Shards optimizer states + gradients across data parallel workers/GPUs.
    For this, add `--fsdp shard_grad_op` to the command line arguments.
  - NO_SHARD : No sharding. For this, add `--fsdp no_shard` to the command line arguments.
  - HYBRID_SHARD : Applies FULL_SHARD within each node and replicates parameters across nodes. For this, add `--fsdp hybrid_shard` to the command line arguments.
  - HYBRID_SHARD_ZERO2 : Applies SHARD_GRAD_OP within each node and replicates parameters across nodes. For this, add `--fsdp hybrid_shard_zero2` to the command line arguments.
- To offload the parameters and gradients to the CPU,
  add `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op offload"` to the command line arguments.
- To automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`,
  add `--fsdp "full_shard auto_wrap"` or `--fsdp "shard_grad_op auto_wrap"` to the command line arguments.
- To enable both CPU offloading and auto wrapping,
  add `--fsdp "full_shard offload auto_wrap"` or `--fsdp "shard_grad_op offload auto_wrap"` to the command line arguments.
- Remaining FSDP config is passed via `--fsdp_config <path_to_fsdp_config.json>`. It is either the location of
  an FSDP json config file (e.g., `fsdp_config.json`) or an already loaded json file as a `dict`.
  - If auto wrapping is enabled, you can either use transformer based auto wrap policy or size based auto wrap policy.
    - For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
      This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
      This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
      Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
      Remaining layers including the shared embeddings are conveniently wrapped in the same outermost FSDP unit.
      Therefore, use this for transformer based models.
    - For size based auto wrap policy, please add `min_num_params` in the config file.
      It specifies FSDP's minimum number of parameters for auto wrapping.
  - `backward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
    `backward_pre` and `backward_post` are available options.
    For more information refer to `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
  - `forward_prefetch` can be specified in the config file. It controls when to prefetch next set of parameters.
    If `"True"`, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
  - `limit_all_gathers` can be specified in the config file.
    If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers.
  - `activation_checkpointing` can be specified in the config file.
    If `"True"`, FSDP activation checkpointing is enabled. It is a technique to reduce memory usage by clearing activations of
    certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time
    for reduced memory usage.
  - `use_orig_params` can be specified in the config file.
    If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. This also makes it possible to have different optimizer param groups. This should be `True` when creating an optimizer object before preparing/wrapping the model with FSDP.
    Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
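
As an illustration of how these options fit together, the sketch below writes an `fsdp_config.json` and assembles a matching command line. The keys are the ones described above, but every value is an illustrative choice rather than a recommended default. JSON booleans are used on the assumption that your installed version accepts them, and `torchrun` with `trainer-program.py` is reused from the earlier examples.

```python
import json
import shlex
import tempfile
from pathlib import Path

# Keys are the ones described above; the values are illustrative choices, not defaults.
fsdp_config = {
    "transformer_layer_cls_to_wrap": ["BertLayer"],  # transformer based auto wrap policy
    "backward_prefetch": "backward_pre",
    "forward_prefetch": False,
    "limit_all_gathers": True,
    "activation_checkpointing": False,
    "use_orig_params": True,
}

# Write the config to a temporary directory so the example leaves no files behind
# in the working directory.
config_path = Path(tempfile.mkdtemp()) / "fsdp_config.json"
config_path.write_text(json.dumps(fsdp_config, indent=2))

# Sharding strategy plus the optional modifiers go into a single --fsdp value.
fsdp_flag = " ".join(["full_shard", "auto_wrap"])
command = [
    "torchrun", "--nproc_per_node=2", "trainer-program.py",
    "--fsdp", fsdp_flag,
    "--fsdp_config", str(config_path),
]
print(shlex.join(command))
```

To use the size based policy instead, replace `transformer_layer_cls_to_wrap` with `min_num_params`.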

**Saving and loading**
Saving entire intermediate checkpoints using `FULL_STATE_DICT` state_dict_type with CPU offloading on rank 0 takes a lot of time and often results in NCCL Timeout errors due to indefinite hanging during broadcasting. However, at the end of training, we want the whole model state dict instead of the sharded state dict which is only compatible with FSDP. Use `SHARDED_STATE_DICT` (default) state_dict_type to save the intermediate checkpoints and optimizer states in this format recommended by the PyTorch team. 

Saving the final checkpoint in transformers format using the default `safetensors` format requires the changes below.
```python
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(script_args.output_dir)
```

**A few caveats to be aware of**

- FSDP is incompatible with `generate`, and therefore with `--predict_with_generate`
  in all seq2seq/clm scripts (translation/summarization/clm etc.).
  Please refer to issue [#21667](https://github.com/huggingface/transformers/issues/21667).

### PyTorch/XLA Fully Sharded Data Parallel

For all the TPU users, great news! PyTorch/XLA now supports FSDP.
All the latest Fully Sharded Data Parallel (FSDP) training features are supported.
For more information refer to [Scaling PyTorch models on Cloud TPUs with FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/) and the [PyTorch/XLA implementation of FSDP](https://github.com/pytorch/xla/tree/master/torch_xla/distributed/fsdp).
All you need to do is enable it through the config.

**Required PyTorch/XLA version for FSDP support**: >=2.0

**Usage**:

Pass `--fsdp "full_shard"` along with the following changes to be made in `--fsdp_config <path_to_fsdp_config.json>`:
- `xla` should be set to `True` to enable PyTorch/XLA FSDP.
- `xla_fsdp_settings`: a dictionary that stores the XLA FSDP wrapping parameters.
  For a complete list of options, please see [here](https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
- `xla_fsdp_grad_ckpt`: when `True`, uses gradient checkpointing over each nested XLA FSDP wrapped layer.
  This setting can only be used when the `xla` flag is set to true, and an auto wrapping policy is specified through
  `min_num_params` or `transformer_layer_cls_to_wrap`.
- You can either use transformer based auto wrap policy or size based auto wrap policy.
  - For transformer based auto wrap policy, it is recommended to specify `transformer_layer_cls_to_wrap` in the config file. If not specified, the default value is `model._no_split_modules` when available.
    This specifies the list of transformer layer class names (case-sensitive) to wrap, e.g., [`BertLayer`], [`GPTJBlock`], [`T5Block`] ....
    This is important because submodules that share weights (e.g., embedding layer) should not end up in different FSDP wrapped units.
    Using this policy, wrapping happens for each block containing Multi-Head Attention followed by a couple of MLP layers.
    Remaining layers including the shared embeddings are conveniently wrapped in the same outermost FSDP unit.
    Therefore, use this for transformer based models.
  - For size based auto wrap policy, please add `min_num_params` in the config file.
    It specifies FSDP's minimum number of parameters for auto wrapping.
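
The dependency between `xla_fsdp_grad_ckpt`, `xla` and the auto wrap policy is easy to get wrong, so here is a minimal sketch of a config together with a sanity check. `check_xla_fsdp_config` is a hypothetical helper that only encodes the rule stated above. The values are placeholders, and `xla_fsdp_settings` is deliberately left empty.

```python
import json

# Keys are the ones listed above; the values are placeholders chosen for illustration.
xla_fsdp_config = {
    "xla": True,
    "xla_fsdp_settings": {},      # fill in with options from the XLA FSDP class
    "xla_fsdp_grad_ckpt": True,
    "min_num_params": 1_000_000,  # size based auto wrap policy; arbitrary threshold
}


def check_xla_fsdp_config(config):
    """Return a list of problems with the config, following the rule described above."""
    problems = []
    has_policy = bool(config.get("min_num_params")) or bool(
        config.get("transformer_layer_cls_to_wrap")
    )
    if config.get("xla_fsdp_grad_ckpt"):
        if not config.get("xla"):
            problems.append("xla_fsdp_grad_ckpt requires xla to be true")
        if not has_policy:
            problems.append("xla_fsdp_grad_ckpt requires an auto wrap policy")
    return problems


print(check_xla_fsdp_config(xla_fsdp_config))  # an empty list means no problems found
print(json.dumps(xla_fsdp_config, indent=2))   # content to save as the config file
```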

### Using Trainer for accelerated PyTorch Training on Mac 

With PyTorch v1.12 release, developers and researchers can take advantage of Apple silicon GPUs for significantly faster model training. 
This unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac.
Apple's Metal Performance Shaders (MPS) as a backend for PyTorch enables this and can be used via the new `"mps"` device. 
This will map computational graphs and primitives on the MPS Graph framework and tuned kernels provided by MPS.
For more information please refer to the official documents [Introducing Accelerated PyTorch Training on Mac](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/)
and [MPS BACKEND](https://pytorch.org/docs/stable/notes/mps.html).

<Tip warning={false}>

We strongly recommend installing PyTorch >= 1.13 (nightly version at the time of writing) on your macOS machine.
It has major fixes related to model correctness and performance improvements for transformer based models.
Please refer to https://github.com/pytorch/pytorch/issues/82707 for more details.

</Tip>

**Benefits of Training and Inference using Apple Silicon Chips**

1. Enables users to train larger networks or batch sizes locally
2. Reduces data retrieval latency and provides the GPU with direct access to the full memory store due to unified memory architecture,
thereby improving end-to-end performance.
3. Reduces costs associated with cloud-based development or the need for additional local GPUs.

**Pre-requisites**: To install torch with mps support,
please follow this Medium article [GPU-Acceleration Comes to PyTorch on M1 Macs](https://medium.com/towards-data-science/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1).

**Usage**:
The `mps` device is used by default if available, similar to the way the `cuda` device is used.
Therefore, no action from the user is required.
For example, you can run the official GLUE text classification task (from the root folder) using Apple Silicon GPU with the command below:

```bash
export TASK_NAME=mrpc

python examples/pytorch/text-classification/run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/$TASK_NAME/ \
  --overwrite_output_dir
```

**A few caveats to be aware of**

1. Some PyTorch operations have not been implemented in mps and will throw an error.
One way to get around that is to set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1`,
which will fall back to the CPU for these operations. It still throws a UserWarning however.
2. The distributed backends `gloo` and `nccl` do not work with the `mps` device.
This means that currently only a single GPU of `mps` device type can be used.

Finally, please remember that the 🤗 `Trainer` only integrates the MPS backend. If you
have any problems or questions regarding MPS backend usage, please
file an issue with [PyTorch GitHub](https://github.com/pytorch/pytorch/issues).
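
If you want to confirm what your machine offers before starting a run, a check along the following lines may help. This is only a sketch: `pick_device` is a hypothetical helper, the cuda/mps/cpu order is an assumption made for illustration, and [`Trainer`] performs its own device selection regardless of this code.

```python
import os

# Enable the CPU fallback for operations that mps does not implement yet.
# Set early (before importing torch) to be safe.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def pick_device():
    """Return "cuda", "mps" or "cpu" depending on what this machine offers."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed, so no accelerator can be used
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)  # absent before PyTorch 1.12
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


print(pick_device())
```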

Sections that were moved:

[ <a href="./deepspeed#deepspeed-trainer-integration">DeepSpeed</a><a id="deepspeed"></a>
| <a href="./deepspeed#deepspeed-installation">Installation</a><a id="installation"></a>
| <a href="./deepspeed#deepspeed-multi-gpu">Deployment with multiple GPUs</a><a id="deployment-with-multiple-gpus"></a>
| <a href="./deepspeed#deepspeed-one-gpu">Deployment with one GPU</a><a id="deployment-with-one-gpu"></a>
| <a href="./deepspeed#deepspeed-notebook">Deployment in Notebooks</a><a id="deployment-in-notebooks"></a>
| <a href="./deepspeed#deepspeed-config">Configuration</a><a id="configuration"></a>
| <a href="./deepspeed#deepspeed-config-passing">Passing Configuration</a><a id="passing-configuration"></a>
| <a href="./deepspeed#deepspeed-config-shared">Shared Configuration</a><a id="shared-configuration"></a>
| <a href="./deepspeed#deepspeed-zero">ZeRO</a><a id="zero"></a>
| <a href="./deepspeed#deepspeed-zero2-config">ZeRO-2 Config</a><a id="zero-2-config"></a>
| <a href="./deepspeed#deepspeed-zero3-config">ZeRO-3 Config</a><a id="zero-3-config"></a>
| <a href="./deepspeed#deepspeed-nvme">NVMe Support</a><a id="nvme-support"></a>
| <a href="./deepspeed#deepspeed-zero2-zero3-performance">ZeRO-2 vs ZeRO-3 Performance</a><a id="zero-2-vs-zero-3-performance"></a>
| <a href="./deepspeed#deepspeed-zero2-example">ZeRO-2 Example</a><a id="zero-2-example"></a>
| <a href="./deepspeed#deepspeed-zero3-example">ZeRO-3 Example</a><a id="zero-3-example"></a>
| <a href="./deepspeed#deepspeed-optimizer">Optimizer</a><a id="optimizer"></a>
| <a href="./deepspeed#deepspeed-scheduler">Scheduler</a><a id="scheduler"></a>
| <a href="./deepspeed#deepspeed-fp32">fp32 Precision</a><a id="fp32-precision"></a>
| <a href="./deepspeed#deepspeed-amp">Automatic Mixed Precision</a><a id="automatic-mixed-precision"></a>
| <a href="./deepspeed#deepspeed-bs">Batch Size</a><a id="batch-size"></a>
| <a href="./deepspeed#deepspeed-grad-acc">Gradient Accumulation</a><a id="gradient-accumulation"></a>
| <a href="./deepspeed#deepspeed-grad-clip">Gradient Clipping</a><a id="gradient-clipping"></a>
| <a href="./deepspeed#deepspeed-weight-extraction">Getting The Model Weights Out</a><a id="getting-the-model-weights-out"></a>
]