ZeRO
####

The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across
data-parallel processes by partitioning the three model states (optimizer
states, gradients, and parameters) across data-parallel processes instead of
replicating them. By doing this, it boosts memory efficiency compared to
classic data-parallelism while retaining its computational granularity and
communication efficiency.

#. **ZeRO Stage 1**: The optimizer states (e.g., for the `Adam optimizer <https://arxiv.org/abs/1412.6980>`_, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its partition.

#. **ZeRO Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states.

#. **ZeRO Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes.

In addition, ZeRO-3 includes the *infinity offload engine* to form
ZeRO-Infinity (`paper <https://arxiv.org/abs/2104.07857>`_), which can offload
all model states to both CPU and NVMe memory for huge memory savings.
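
The active stage is selected with the ``stage`` field of the ``zero_optimization``
section of the DeepSpeed configuration. As a minimal sketch (complete ZeRO-3
examples follow below):

.. code-block:: python

    {
        "zero_optimization": {
            "stage": 3  # 0 disables ZeRO; 1, 2, or 3 selects the corresponding stage
        },
        ...
    }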


For a deep dive into our algorithms, please see our `papers <https://www.deepspeed.ai/#publications>`_ on `ZeRO
<https://arxiv.org/abs/1910.02054>`_, `ZeRO-Offload
<https://arxiv.org/abs/2101.06840>`_,
and `ZeRO-Infinity <https://arxiv.org/abs/2104.07857>`_.

.. note::
    DeepSpeed first included offloading capabilities with **ZeRO-Offload**, a
    system for offloading optimizer and gradient states to CPU memory within
    ZeRO-2. **ZeRO-Infinity** is the next generation of offloading
    capabilities, accessible to ZeRO-3. ZeRO-Infinity has all of the savings
    of ZeRO-Offload, can additionally offload the model weights, and achieves
    more effective bandwidth utilization and overlapping of computation and
    communication.



Getting Started
---------------

If you are new to DeepSpeed, check out our `Getting Started <https://www.deepspeed.ai/getting-started/>`_ page.

Once you are training with DeepSpeed, enabling ZeRO-3 offload is as simple as turning it on
in your DeepSpeed configuration! Below are a few examples of ZeRO-3 configurations. Please see
our `config guide <https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training>`_
for a complete list of options for configuration and performance tuning.

.. note::
        ZeRO-Infinity and ZeRO-Offload work best with our heavily optimized
        :class:`deepspeed.ops.adam.DeepSpeedCPUAdam` optimizer. We recommend using
        our `optimizer config <https://www.deepspeed.ai/docs/config-json/#optimizer-parameters>`_
        to instruct :meth:`deepspeed.initialize` to build the optimizer for you.
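
For orientation, here is a rough sketch of how a ZeRO-3 configuration is typically
wired into training. ``MyModel`` and the config values below are placeholders, not
part of the DeepSpeed API, and depending on your DeepSpeed version the configuration
may be passed to :meth:`deepspeed.initialize` via the ``config`` or ``config_params``
argument.

.. code-block:: python

    import deepspeed

    # Placeholder model; replace with your own torch.nn.Module.
    model = MyModel()

    ds_config = {
        "train_batch_size": 8,
        "fp16": {"enabled": True},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "zero_optimization": {
            "stage": 3,
            # With optimizer CPU offload, DeepSpeed can substitute DeepSpeedCPUAdam.
            "offload_optimizer": {"device": "cpu"},
        },
    }

    # deepspeed.initialize builds the optimizer described in the config and
    # wraps the model in a ZeRO-3 engine.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )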

ZeRO Configurations
===================

All the settings for DeepSpeed ZeRO are set with the `DeepSpeedZeroConfig`_.
The dictionary provided under the ``zero_optimization`` entry of the main
DeepSpeed configuration dict will be parsed and validated with this class.
Sub-configurations for parameter offload and optimizer offload settings are
parsed by `DeepSpeedZeroOffloadParamConfig`_ and
`DeepSpeedZeroOffloadOptimizerConfig`_.

.. _DeepSpeedZeroConfig:
.. autopydantic_model:: deepspeed.runtime.zero.config.DeepSpeedZeroConfig

.. _DeepSpeedZeroOffloadParamConfig:
.. autopydantic_model:: deepspeed.runtime.zero.config.DeepSpeedZeroOffloadParamConfig

.. _DeepSpeedZeroOffloadOptimizerConfig:
.. autopydantic_model:: deepspeed.runtime.zero.config.DeepSpeedZeroOffloadOptimizerConfig


Example ZeRO-3 Configurations
=============================

#. Use ZeRO to partition the optimizer states (stage 1), gradients (stage 2),
   and parameters (stage 3).

    .. code-block:: python
        :emphasize-lines: 3

        {
            "zero_optimization": {
                "stage": 3,
            },
            "fp16": {
                "enabled": true
            },
            "optimizer": {
                "type": "AdamW",
                "params": {
                "lr": 0.001,
                "betas": [
                    0.8,
                    0.999
                ],
                "eps": 1e-8,
                "weight_decay": 3e-7
                }
            },
            ...
        }


#. Additionally offload the optimizer states and computations to the CPU with ZeRO-Infinity.

    .. code-block:: python

        {
            "zero_optimization": {
                "stage": 3,
                "offload_optimizer": {
                    "device": "cpu"
                }
            },
            ...
        }


#. Save even more memory by also offloading the parameters to CPU memory.

    .. code-block:: python

        {
            "zero_optimization": {
                "stage": 3,
                "offload_optimizer": {
                    "device": "cpu"
                },
                "offload_param": {
                    "device": "cpu"
                }
            },
            ...
        }


#. Save even MORE memory by offloading to NVMe (if available on your system):

    .. code-block:: python

        {
            "zero_optimization": {
                "stage": 3,
                "offload_optimizer": {
                    "device": "nvme",
                    "nvme_path": "/nvme_data"
                },
                "offload_param": {
                    "device": "nvme",
                    "nvme_path": "/nvme_data"
                }
            },
            ...
        }



Assumptions
===========

DeepSpeed automatically coordinates the collection (*i.e.,* all-gather),
partitioning (*i.e.,* scatter), and offloading of parameters at the
granularity of (sub)module ``forward()`` methods. The backward pass is
handled similarly. This strategy has two underlying assumptions:

#. The forward and backward passes of submodules must individually fit in device memory.
   If this is not the case, :class:`deepspeed.zero.TiledLinear` implements
   **memory-centric tiling** and works with ZeRO-3 to break linear layers
   into a sequence of smaller submodules that can fit in memory.

#. A module's parameters are only accessed within its own ``__init__`` and ``forward()`` methods.
   Otherwise, DeepSpeed must be instructed to collect and re-partition the parameter.
   See :ref:`external-parameters` for manually coordinating parameters.


Constructing Massive Models
---------------------------

ZeRO-3 enables massive models whose parameters exceed the memory of individual
nodes in a system. For the typical case of training without model parallelism,
you can simply allocate your model within our context manager:

.. code-block:: python

    with deepspeed.zero.Init():
        model = MyLargeModel()
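
If parameters should be placed on CPU or NVMe as they are allocated, the same
DeepSpeed configuration can also be handed to the context. A minimal sketch,
assuming ``ds_config`` is the dictionary later given to :meth:`deepspeed.initialize`
(the argument name may differ across DeepSpeed versions, e.g. ``config_dict_or_path``):

.. code-block:: python

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model = MyLargeModel()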


.. autoclass:: deepspeed.zero.Init
    :members:


.. _external-parameters:

Manual Parameter Coordination
-----------------------------

Most models require no modification to be trained with ZeRO-3. However, in
some cases one may need to access model weights outside of the training loop,
or to share weights across submodules during training. DeepSpeed has
several mechanisms to coordinate partitioned weights for ZeRO-3.


Gathering Parameters
====================

DeepSpeed provides mechanisms for collecting (or *gathering*) a partitioned parameter.

Some models partitioned with :class:`deepspeed.zero.Init` may need to access
a module’s weights outside of the class constructor or its ``forward()``
method. We refer to these weights as **external parameters**, since these
parameters are accessed outside of the module that created them. To access such weights, use
:class:`deepspeed.zero.GatheredParameters` or :meth:`deepspeed.zero.register_external_parameter`.
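
For example, a partitioned weight can be gathered so that one rank can inspect or
modify the full tensor. A minimal sketch, assuming ``model.embeddings`` is a
submodule whose weights were partitioned by ZeRO-3:

.. code-block:: python

    import torch
    import deepspeed

    # Gather the full (unpartitioned) embedding weight on every rank. With
    # modifier_rank=0, changes made by rank 0 inside the context are broadcast
    # to all ranks when the context exits.
    with deepspeed.zero.GatheredParameters(model.embeddings.weight, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            model.embeddings.weight.data.normal_(mean=0.0, std=0.02)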

.. autoclass:: deepspeed.zero.GatheredParameters
    :members:


Registering External Parameters
===============================

ZeRO-3 will automatically collect and partition the model parameters as they
are needed during the forward and backward passes. However, in some cases a
parameter may be used outside of its module's forward pass. We call these
*external* parameters. ZeRO-3 can coordinate these parameters if they are
registered either automatically or manually.


.. note::
    DeepSpeed version ``0.3.15`` includes automatic external parameter
    discovery and registration to support the most common cases. Parameters
    can still be manually registered if they cannot be automatically
    detected.


DeepSpeed can automatically detect the following external parameter scenarios:


#. Parameter access: consider the following pattern common in language models such as GPT:

   .. code-block:: python

       class LanguageModel(torch.nn.Module):
           ...
           def forward(self, inputs):
               embeds = self.embeddings(inputs)
               ...
               logits = compute_logits(output, self.embeddings.weight)
               ...

   The tensor ``embeddings.weight`` is used in both ``embeddings.forward()`` and
   ``compute_logits()``. We call ``embeddings.weight`` an *external* parameter
   because it is used in the training loop outside of its owning module's
   forward pass.

#. Returning a parameter:

   ``CustomLinear`` returns both an output and its own ``bias`` parameter. DeepSpeed
   will detect the external ``bias`` parameter and register it with submodules that
   use ``CustomLinear``.

   .. code-block:: python

       class CustomLinear(torch.nn.Linear):
           def forward(self, *input):
               output = super().forward(*input)
               return output, self.bias
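
When a parameter access pattern is not detected automatically, it can be registered
by hand with :meth:`deepspeed.zero.register_external_parameter`. The following is a
hypothetical sketch (the module and its layout are illustrative, not part of DeepSpeed):

.. code-block:: python

    import torch
    import deepspeed

    class TiedLanguageModel(torch.nn.Module):
        def __init__(self, vocab_size, hidden_dim):
            super().__init__()
            self.embeddings = torch.nn.Embedding(vocab_size, hidden_dim)
            # embeddings.weight is reused below to tie the output projection,
            # i.e. it is accessed outside of embeddings.forward(), so register
            # it as an external parameter of this module.
            deepspeed.zero.register_external_parameter(self, self.embeddings.weight)

        def forward(self, inputs):
            hidden = self.embeddings(inputs)
            # ... transformer layers would go here ...
            logits = torch.matmul(hidden, self.embeddings.weight.t())
            return logits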



.. autofunction:: deepspeed.zero.register_external_parameter

.. autofunction:: deepspeed.zero.unregister_external_parameter


Memory-Centric Tiling
---------------------

To reduce the working memory requirements of DL training for large models,
ZeRO-Infinity includes a technique called *memory-centric tiling* that
exploits the data fetch and release pattern of ZeRO-3: a large operator is
broken down into smaller tiles that are executed sequentially, and the
parameters and gradients of each tile are fetched and released one at a
time, shrinking the working memory in proportion to the number of tiles.
ZeRO-Infinity can therefore support operators of arbitrary sizes, without
refactoring for model parallelism to fit them in limited GPU memory.
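
As a rough sketch, a very large linear layer can be swapped for its tiled
equivalent; ``in_splits`` and ``out_splits`` control how many tiles the input
and output dimensions are split into (the sizes below are purely illustrative):

.. code-block:: python

    import deepspeed

    # A single 64k x 64k linear layer holds roughly 4 billion parameters.
    # Splitting it into 4 x 4 tiles lets ZeRO-3 fetch and release one tile's
    # parameters at a time instead of materializing the whole layer.
    hidden_dim = 65536
    tiled = deepspeed.zero.TiledLinear(in_features=hidden_dim,
                                       out_features=hidden_dim,
                                       in_splits=4,
                                       out_splits=4)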


.. autoclass:: deepspeed.zero.TiledLinear
    :members:


Debugging
---------

Debugging ZeRO training is complicated by the partitioning of parameters, gradients, and optimizer states: none of these three groups of tensors (model states) can be accessed directly in their partitioned form. To overcome this, DeepSpeed provides the following routines for accessing individual model states in their unpartitioned form.

Important: these utilities must be called by all processes participating in the training, even if you only use the result in the main process. If some processes do not participate, the utilities will hang while waiting for the other processes to send their contribution.

Additionally, be aware that these routines return correct data only in specific phases of training. For example, the gradients are valid after ``backward`` and before ``step``, and the optimizer states are updated after ``step``. The same holds for the fp32 master weights.

.. autofunction:: deepspeed.utils.safe_get_full_fp32_param

.. autofunction:: deepspeed.utils.safe_get_full_grad

.. autofunction:: deepspeed.utils.safe_get_full_optimizer_state


These routines can be used in a training loop as shown in the following snippet.

.. code-block:: python

    backward(loss)
    [...]
    from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
    for n, lp in model.named_parameters():
        # 1. gradient lookup
        # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
        # For zero3, gradient lookup must be called after `backward`
        hp_grad = safe_get_full_grad(lp)

        # 2. fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
        hp = safe_get_full_fp32_param(lp)
        exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
        exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")

    [...]
    optimizer.step()