oneflow.distributed
=========================================================

.. note ::
    Please refer to `OneFlow Distributed Overview <https://docs.oneflow.org/master/parallelism/01_introduction.html>`__
    for a brief introduction to all features related to distributed training.

OneFlow provides two ways to accomplish `Distributed Training`:

- The recommended way is to use OneFlow's Global Tensor for distributed training. A Global Tensor treats the computing cluster as a single supercomputing device, allowing users to write distributed training code just as they would in a single-machine environment.

- OneFlow also provides a DDP (DistributedDataParallel) module aligned with PyTorch. DDP is well known and widely used for data parallelism among PyTorch users. See also the `PyTorch DDP introduction <https://pytorch.org/docs/1.10/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.



Basic
-------------------------------
When you start distributed training in OneFlow, the following functions can be used.

.. currentmodule:: oneflow.env

.. autosummary::
    :toctree: generated
    :nosignatures:

    get_world_size
    get_rank
    get_local_rank
    get_node_size
    init_rdma
    rdma_is_initialized
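
A minimal sketch of querying the distributed environment with these functions (assuming the script is started with ``oneflow.distributed.launch``):

::

    import oneflow as flow

    # Each of these reads the environment set up by the launcher.
    print(flow.env.get_world_size())  # total number of processes in the cluster
    print(flow.env.get_rank())        # global rank of the current process
    print(flow.env.get_local_rank())  # rank of the current process within its node
    print(flow.env.get_node_size())   # number of nodes (machines) in the cluster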


`Global Tensor`
--------------------------------------------------------------

Construct `Global Tensor`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
A `Global Tensor` can be created with a ``placement`` and an ``sbp``. The ``placement`` describes the physical devices on which the global tensor will be allocated, and the ``sbp`` describes how its data is distributed among these devices.

::

    >>> import oneflow as flow
    >>> # Place a global tensor on the cuda devices of ranks (processes) 0 and 1
    >>> placement = flow.placement(type="cuda", ranks=[0, 1])
    >>> # Each rank's local data is a slice obtained by splitting the global data on dim 0
    >>> sbp = flow.sbp.split(dim=0)
    >>> # Create a global tensor by randn
    >>> x = flow.randn(4, 5, placement=placement, sbp=sbp)
    >>> x.shape
    oneflow.Size([4, 5])


Convert `Local Tensor` to `Global Tensor`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

With the ``Tensor.to_global`` interface, a `Local Tensor` can be used to create a `Global Tensor` that takes the `Local Tensor` as its local component on the current node.

In the example below, two `local tensors` with shape ``(2, 5)`` are created separately on two nodes. After the ``to_global`` call, a `global tensor` with shape ``(4, 5)`` is obtained.

Code running on Node 0

::

    import oneflow as flow

    x = flow.randn(2,5)
    placement = flow.placement("cuda", [0,1])
    sbp = flow.sbp.split(0)
    x_global = x.to_global(placement=placement, sbp=sbp)
    print(x_global.shape)  # oneflow.Size([4, 5])

Code running on Node 1

::

    import oneflow as flow

    x = flow.randn(2,5)
    placement = flow.placement("cuda", [0,1])
    sbp = flow.sbp.split(0)
    x_global = x.to_global(placement=placement, sbp=sbp)
    print(x_global.shape)  # oneflow.Size([4, 5])

Redistribute `Global Tensor`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Redistributing a `Global Tensor` means moving its data to another device group (or placement), or changing its data distribution (or SBP) across the group, or both at the same time. The redistributed tensor is still a `Global Tensor`.

::

    >>> import oneflow as flow
    >>> x = flow.tensor([1.0, 2.0], placement=flow.placement("cuda", ranks=[0, 1]), sbp=flow.sbp.split(0))
    >>> y = x.to_global(placement=flow.placement("cuda", ranks=[2, 3]), sbp=flow.sbp.broadcast)

OneFlow defines a set of valid input and output SBP combinations for each built-in operator according to the operator's semantics, so it can automatically redistribute a `Global Tensor` to satisfy the operator's SBP requirements for its input tensors. For example, in the following code:

::

    >>> import oneflow as flow
    >>> x = flow.randn(4, 4, 
            placement=flow.placement("cuda", ranks=[0, 1]), 
            sbp=flow.sbp.split(0))
    >>> y = flow.randn(4, 4, 
            placement=flow.placement("cuda", ranks=[0, 1]), 
            sbp=flow.sbp.split(1))
    >>> z = x + y

When ``x + y`` is executed, ``x`` is split along dimension ``0`` while ``y`` is split along dimension ``1``, so their local components on each node cannot be added directly. OneFlow then automatically redistributes one of ``x`` and ``y`` so that both have the same SBP, and completes the add operation.
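
To see how the result is distributed, the output's ``placement`` and ``sbp`` attributes can be inspected (a small sketch; the exact SBP that OneFlow picks for the result is decided by the operator's SBP rules):

::

    print(z.shape)      # oneflow.Size([4, 4])
    print(z.placement)  # same devices as the inputs
    print(z.sbp)        # the SBP signature OneFlow chose for the result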

.. note ::
    - Global Tensor cannot be used in combination with DDP currently.
    - Global Tensor requires all devices to execute at the same pace; otherwise, a multi-process deadlock may occur.

Get Local Tensor from Global Tensor
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
With the ``Tensor.to_local`` interface, a `Global Tensor` can return its local component on the current node.

::

    >>> y = x.to_local()
    >>> y.is_local
    True
    >>> y
    tensor([[ 2.9186e-01, -3.9442e-01,  4.7072e-04, -3.2216e-01,  1.7788e-01],
            [-4.5284e-01,  1.2361e-01, -3.5962e-01,  2.6651e-01,  1.2951e+00]],
           device='cuda:0', dtype=oneflow.float32)


DistributedDataParallel
--------------------------------------------------------------

For more information about DistributedDataParallel, see ``oneflow.nn.parallel.DistributedDataParallel``.

The following script shows the process of using ``oneflow.nn.parallel.DistributedDataParallel`` for data-parallel training:

.. code-block:: python

    import oneflow as flow
    from oneflow.nn.parallel import DistributedDataParallel as ddp

    train_x = [
        flow.tensor([[1, 2], [2, 3]], dtype=flow.float32),
        flow.tensor([[4, 6], [3, 1]], dtype=flow.float32),
    ]
    train_y = [
        flow.tensor([[8], [13]], dtype=flow.float32),
        flow.tensor([[26], [9]], dtype=flow.float32),
    ]


    class Model(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.lr = 0.01
            self.iter_count = 500
            self.w = flow.nn.Parameter(flow.tensor([[0], [0]], dtype=flow.float32))

        def forward(self, x):
            x = flow.matmul(x, self.w)
            return x


    m = Model().to("cuda")
    m = ddp(m)
    loss = flow.nn.MSELoss(reduction="sum")
    optimizer = flow.optim.SGD(m.parameters(), m.lr)

    for i in range(0, m.iter_count):
        rank = flow.env.get_rank()
        x = train_x[rank].to("cuda")
        y = train_y[rank].to("cuda")

        y_pred = m(x)
        l = loss(y_pred, y)
        if (i + 1) % 50 == 0:
            print(f"{i+1}/{m.iter_count} loss:{l}")

        optimizer.zero_grad()
        l.backward()
        optimizer.step()

    print(f"\nw:{m.w}")

There are only two differences between this data-parallel training code and a stand-alone single-card script:

- Use `DistributedDataParallel` to wrap the module object (`m = ddp(m)`).
- Use `get_rank` to get the current process rank and send the corresponding slice of data to that device.

Then use the `launcher` to run the script and leave everything else to OneFlow, which makes distributed training as simple as stand-alone single-card training:

::

    python3 -m oneflow.distributed.launch --nproc_per_node 2 ./ddp_train.py


Communication collectives
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. currentmodule:: oneflow.comm
    
.. autosummary::
    :toctree: generated
    :nosignatures:

    all_reduce
    all_gather
    all_gather_into_tensor
    all_to_all
    broadcast
    barrier
    gather
    reduce
    reduce_scatter
    reduce_scatter_tensor
    recv
    scatter
    send
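
A minimal sketch of one of these collectives, ``oneflow.comm.all_reduce``, run under the launcher with two ranks (the script name is a placeholder; the call reduces the tensor in place, so every rank ends up with the element-wise sum):

::

    # Run with: python3 -m oneflow.distributed.launch --nproc_per_node 2 ./all_reduce_demo.py
    import oneflow as flow

    rank = flow.env.get_rank()
    # rank 0 holds [1., 2.], rank 1 holds [2., 3.]
    tensor = flow.tensor([1.0, 2.0], device="cuda") + rank
    flow.comm.all_reduce(tensor)
    print(tensor)  # both ranks now hold [3., 5.]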

Launching distributed training
--------------------------------------------------------------

.. currentmodule:: oneflow.distributed

Run the command below to see more about its usage.

::

    python3 -m oneflow.distributed.launch -h

.. code-block::

    usage: launch.py [-h] [--nnodes NNODES] [--node_rank NODE_RANK]
                     [--nproc_per_node NPROC_PER_NODE] [--master_addr MASTER_ADDR]
                     [--master_port MASTER_PORT] [-m] [--no_python]
                     [--redirect_stdout_and_stderr] [--logdir LOGDIR]
                     training_script ...

    OneFlow distributed training launch helper utility that will spawn up multiple
    distributed processes

    positional arguments:
      training_script       The full path to the single GPU training program/script to be
                            launched in parallel, followed by all the arguments for the
                            training script
      training_script_args

    optional arguments:
      -h, --help            show this help message and exit
      --nnodes NNODES       The number of nodes to use for distributed training
      --node_rank NODE_RANK
                            The rank of the node for multi-node distributed training
      --nproc_per_node NPROC_PER_NODE
                            The number of processes to launch on each node, for GPU
                            training, this is recommended to be set to the number of GPUs in
                            your system so that each process can be bound to a single GPU.
      --master_addr MASTER_ADDR
                            Master node (rank 0)'s address, should be either the IP address
                            or the hostname of node 0, for single node multi-proc training,
                            the --master_addr can simply be 127.0.0.1
      --master_port MASTER_PORT
                            Master node (rank 0)'s free port that needs to be used for
                            communication during distributed training
      -m, --module          Changes each process to interpret the launch script as a python
                            module, executing with the same behavior as 'python -m'.
      --no_python           Do not prepend the training script with "python" - just exec it
                            directly. Useful when the script is not a Python script.
      --redirect_stdout_and_stderr
                            write the stdout and stderr to files 'stdout' and 'stderr'. Only
                            available when logdir is set
      --logdir LOGDIR       Relative path to write subprocess logs to. Passing in a relative
                            path will create a directory if needed. Note that successive
                            runs with the same path to write logs to will overwrite existing
                            logs, so be sure to save logs as needed.