# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Methods needed for distributed training (DP/TP)."""
from contextlib import contextmanager
from typing import Any, Dict, Union, Optional, Callable, Tuple

import torch
from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable

from .utils import safely_set_viewless_tensor_data
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager

# Names exported by a star-import of this module.
__all__ = ["checkpoint", "CudaRNGStatesTracker"]


# Attribute names used to tag a tensor with tensor-parallel partitioning info.
# Within this file only the keys are consulted:
# `set_tensor_model_parallel_attributes` asserts that none of them is already
# present on a tensor before it sets them.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
    "tensor_model_parallel": False,
    "partition_dim": -1,
    "partition_stride": 1,
}

# Module-level flags written by `activation_recompute_forward` and read back
# through `is_fp8_activation_recompute_enabled` and
# `in_fp8_activation_recompute_phase`.
_FP8_ACTIVATION_RECOMPUTE_ENABLED = False
_FP8_ACTIVATION_RECOMPUTE_PHASE = False


def _set_cuda_rng_state(new_state: torch.Tensor, device: Union[int, str] = -1) -> None:
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state)

    _lazy_call(cb)


def set_tensor_model_parallel_attributes(
    tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int
) -> None:
    """Tag `tensor` with the attributes needed for tensor parallelism (TP).

    Arguments:
        tensor: tensor to annotate; it must not carry any of the attributes
            named in `_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS` yet (asserted).
        is_parallel: stored as `tensor.tensor_model_parallel`.
        dim: stored as `tensor.partition_dim`.
        stride: stored as `tensor.partition_stride`.
    """
    # Refuse to overwrite attributes that were already assigned.
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attr_name)

    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride


def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int:
    """Return world size for the distributed group.

    Arguments:
        group: process group to query; `None` is forwarded unchanged to
            `torch.distributed.get_world_size`.

    Returns:
        The group's world size, or 1 when `torch.distributed` is unavailable in
        this build or has not been initialized.
    """
    # `is_available()` is checked first so that builds compiled without
    # distributed support take the single-process path instead of failing.
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 1
    return torch.distributed.get_world_size(group=group)


def get_distributed_rank(group: Optional[dist_group_type] = None) -> int:
    """Return my rank for the distributed group.

    Arguments:
        group: process group to query; `None` is forwarded unchanged to
            `torch.distributed.get_rank`.

    Raises:
        AssertionError: if `torch.distributed` is unavailable or has not been
            initialized.
    """
    assert (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    ), "torch.distributed is not initialized."
    return torch.distributed.get_rank(group=group)


def initialize_affine_weight_gpu(
    weight: torch.Tensor,
    init_method: Callable,
    get_rng_state_tracker: Optional[Callable],
    partition_dim: int = 0,
    stride: int = 1,
    set_tp_attributes: bool = True,
) -> None:
    """Initialize affine weight for model parallel on GPU.

    Arguments:
        weight: tensor to initialize in place via `init_method`.
        init_method: callable invoked as `init_method(weight)`.
        get_rng_state_tracker: callable returning an object with a `fork()`
            context manager, or `None` to run `init_method` without forking.
        partition_dim: forwarded as `dim` when tagging the weight.
        stride: forwarded as `stride` when tagging the weight.
        set_tp_attributes: when true, tag `weight` through
            `set_tensor_model_parallel_attributes` before initializing it.
    """
    if set_tp_attributes:
        set_tensor_model_parallel_attributes(
            tensor=weight, is_parallel=True, dim=partition_dim, stride=stride
        )

    # Without a tracker there is no RNG state to fork; initialize directly.
    if get_rng_state_tracker is None:
        init_method(weight)
        return

    # Run the initializer inside the tracker's forked RNG state.
    with get_rng_state_tracker().fork():
        init_method(weight)


def split_tensor_into_1d_equal_chunks(
    tensor: torch.Tensor, tp_group: dist_group_type, new_buffer: bool = False
) -> torch.Tensor:
    """Return this rank's equally sized 1D chunk of `tensor`.

    The tensor is viewed as flat and cut into `world_size` chunks of
    `numel // world_size` elements; the chunk at this rank's index is returned.

    Arguments:
        tensor: tensor to split; it is flattened with `view(-1)`.
        tp_group: process group supplying world size and rank.
        new_buffer: when true, copy the chunk into a freshly allocated tensor on
            the current CUDA device; otherwise return a view of `tensor`.
    """
    chunk_numel = torch.numel(tensor) // get_distributed_world_size(tp_group)
    first = chunk_numel * get_distributed_rank(tp_group)
    last = first + chunk_numel

    if not new_buffer:
        return tensor.view(-1)[first:last]

    chunk = torch.empty(
        chunk_numel,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    chunk.copy_(tensor.view(-1)[first:last])
    return chunk


def gather_split_1d_tensor(
    tensor: torch.Tensor, tp_group: dist_group_type
) -> torch.Tensor:
    """Opposite of `split_tensor_into_1d_equal_chunks`: gather values from
    model parallel ranks.

    Arguments:
        tensor: this rank's chunk.
        tp_group: process group to gather across.

    Returns:
        A new 1D tensor of `numel(tensor) * world_size` elements allocated on
        the current CUDA device and filled by `all_gather_into_tensor`.
    """
    numel_gathered = torch.numel(tensor) * get_distributed_world_size(tp_group)
    gathered = torch.empty(
        numel_gathered,
        dtype=tensor.dtype,
        device=torch.cuda.current_device(),
        requires_grad=False,
    )
    torch.distributed.all_gather_into_tensor(gathered, tensor, group=tp_group)
    return gathered


@contextmanager
def activation_recompute_forward(
    activation_recompute: bool = False,
    recompute_phase: bool = False,
) -> None:
    """Context manager used to control the forward runtime behavior when executed
    under the `CheckpointFunction` function. For running FP8, the forward pass will
    run without storing intermediate activations. Instead, the forward pass saves
    the inputs tuple and the calling function. In the backwards pass, these are
    retrieved, and the forward pass is computed again while tracking the intermediate
    activations, followed by calculation of gradients using these values.

    Arguments:
        activation_recompute: request recompute mode; the enabled flag is set
            only if `FP8GlobalStateManager.is_fp8_enabled()` is also true.
        recompute_phase: value of the recompute-phase flag inside the context.

    On exit, normal or via an exception, both module-level flags are restored
    to the values they had on entry, so a nested context does not clobber the
    flags of the context that encloses it.
    """
    global _FP8_ACTIVATION_RECOMPUTE_ENABLED, _FP8_ACTIVATION_RECOMPUTE_PHASE
    # Remember the enclosing state so that nesting unwinds correctly.
    previous_flags = (
        _FP8_ACTIVATION_RECOMPUTE_ENABLED,
        _FP8_ACTIVATION_RECOMPUTE_PHASE,
    )
    try:
        _FP8_ACTIVATION_RECOMPUTE_ENABLED = (
            activation_recompute and FP8GlobalStateManager.is_fp8_enabled()
        )
        _FP8_ACTIVATION_RECOMPUTE_PHASE = recompute_phase
        yield
    finally:
        (
            _FP8_ACTIVATION_RECOMPUTE_ENABLED,
            _FP8_ACTIVATION_RECOMPUTE_PHASE,
        ) = previous_flags


def is_fp8_activation_recompute_enabled() -> bool:
    """Return the module-level `_FP8_ACTIVATION_RECOMPUTE_ENABLED` flag.

    The flag is written by `activation_recompute_forward`.
    """
    return _FP8_ACTIVATION_RECOMPUTE_ENABLED


def in_fp8_activation_recompute_phase() -> bool:
    """Return the module-level `_FP8_ACTIVATION_RECOMPUTE_PHASE` flag.

    The flag is written by `activation_recompute_forward`.
    """
    return _FP8_ACTIVATION_RECOMPUTE_PHASE


class CheckpointFunction(torch.autograd.Function):
    """Autograd function that re-runs the forward computation during backward.

    This function is adapted from torch.utils.checkpoint with
    two main changes:
        1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
        2) the states in the model parallel tracker are also properly
           tracked/set/reset.
    """

    @staticmethod
    def forward(
        ctx,
        run_function: Callable,
        distribute_saved_activations: bool,
        get_cuda_rng_tracker: Callable,
        tp_group: dist_group_type,
        kwargs: Dict[str, Any],
        *args: Tuple[torch.Tensor, ...],
    ) -> Tuple[torch.Tensor, ...]:
        """Call forward function while saving state to be able to
        redo the computation later.

        The CPU, CUDA and tracker RNG states are captured before
        `run_function(*args, **kwargs)` is executed under `torch.no_grad()`.
        `args` are saved with `ctx.save_for_backward`; `kwargs` are kept on
        `ctx` as a plain dict.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            with activation_recompute_forward(
                activation_recompute=True, recompute_phase=False
            ):
                outputs = run_function(*args, **kwargs)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(
                    args[0].data, tp_group, new_buffer=True
                ),
            )

        # Store everything.
        ctx.save_for_backward(*args)
        ctx.get_cuda_rng_tracker = get_cuda_rng_tracker
        ctx.tp_group = tp_group
        ctx.kwargs = kwargs

        return outputs

    @staticmethod
    def backward(
        ctx, *args: Tuple[Union[torch.Tensor, None], ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call backward function with activation recomputation.

        The forward RNG states are reinstated, `run_function` is re-run with
        gradients enabled on detached copies of the saved inputs, and the RNG
        states that were current on entry are put back afterwards, also when
        the recomputation raises.

        Raises:
            RuntimeError: if `torch.autograd._is_checkpoint_valid()` is false.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        get_cuda_rng_tracker = ctx.get_cuda_rng_tracker

        # Undo the split applied in forward: gather the first input back to
        # its original shape.
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data, ctx.tp_group).view(
                    ctx.input_0_shape
                ),
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        try:
            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

            # Compute the forward pass.
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                with activation_recompute_forward(
                    activation_recompute=True, recompute_phase=True
                ):
                    outputs = ctx.run_function(*detached_inputs, **ctx.kwargs)
        finally:
            # Set the states back to what it was at the start of this function,
            # even if the recomputation failed.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else inp
            for inp in detached_inputs
        )
        # One `None` per leading argument of `forward` that precedes `*args`:
        # run_function, distribute_saved_activations, get_cuda_rng_tracker,
        # tp_group and kwargs.
        return (None, None, None, None, None) + grads


def checkpoint(
    function: Callable,
    distribute_saved_activations: bool,
    get_cuda_rng_tracker: Callable,
    tp_group: dist_group_type,
    *args: Tuple[torch.Tensor, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    """
    Checkpoint a part of the model by trading compute for memory. This function is based on
    `torch.utils.checkpoint.checkpoint <https://pytorch.org/docs/stable/checkpoint.html>`_.

    .. warning::

        It is the user's responsibility to ensure identical behavior when calling
        :attr:`function` from the forward and backward pass. If different output is
        produced (e.g. due to global state), then the checkpointed version won't
        be numerically equivalent.

    .. warning::

        The tuple :attr:`args` must contain only tensors (or :attr:`None`) in order to comply with
        PyTorch's :attr:`save_for_backward` method. :attr:`function` must be callable to produce
        valid outputs with the inputs :attr:`args` and :attr:`kwargs`.

    Parameters
    ----------
    function: Callable
            pytorch module used to run the forward and backward passes using
            the specified :attr:`args` and :attr:`kwargs`.
    distribute_saved_activations: bool
            if set to `True`, the first tensor argument is distributed across the
            specified tensor parallel group (`tp_group`) before saving it for the
            backward pass.
    get_cuda_rng_tracker: `Callable`
            python callable which returns an instance of :func:`CudaRNGStatesTracker`.
    tp_group : ProcessGroup
            tensor parallel process group. This is a required positional
            argument; the signature defines no default for it.
    args : tuple
            tuple of torch tensors for inputs to :attr:`function`.
    kwargs : dict
            dictionary of string keys for keyword arguments to :attr:`function`.
            It is handed to `CheckpointFunction` as a single dict argument.
    """

    return CheckpointFunction.apply(
        function,
        distribute_saved_activations,
        get_cuda_rng_tracker,
        tp_group,
        kwargs,
        *args,
    )


class CudaRNGStatesTracker:
    """
    For model parallelism, multiple RNG states need to simultaneously exist in order
    to execute operations in or out of the model parallel region. This class keeps
    track of the various RNG states and provides utility methods to maintain them and
    execute parts of the model under a given RNG setting. Using the `add` method, a
    cuda rng state is initialized based on the input `seed` and is assigned to `name`.
    Later, by forking the rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """
        Set to the initial state (no tracker).
        """
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.Tensor]:
        """
        Get rng states. Copy the dictionary so we have direct pointers
        to the states, not just a pointer to the dictionary.
        """
        return dict(self.states_)

    def set_states(self, states: Dict[str, torch.Tensor]) -> None:
        """
        Set the rng states. For efficiency purposes, we do not
        check the size of seed for compatibility.

        states: Dict[str, torch.Tensor]
               A mapping from string names to RNG states.
        """
        # Keep a private copy of the mapping (the state tensors themselves are
        # shared): `fork` rebinds entries of `self.states_`, and doing that on
        # the caller's own dict would silently alter a snapshot the caller may
        # want to reuse.
        self.states_ = dict(states)

    def add(self, name: str, seed: int) -> None:
        """
        Adds a new RNG state.

        name: str
             string identifier for the RNG state.
        seed: int
             PyTorch seed for the RNG state.

        Raises `Exception` if `seed` or `name` is already registered; in that
        case the tracker is left unmodified.
        """
        # Validate both arguments before touching any bookkeeping, so that a
        # rejected call does not leave the seed marked as used.
        if seed in self.seeds_:
            raise Exception(f"seed {seed} already exists")
        if name in self.states_:
            raise Exception(f"cuda rng state {name} already exists")
        self.seeds_.add(seed)
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextmanager
    def fork(self, name: str = "model-parallel-rng") -> None:
        """
        Fork the cuda rng state, perform operations, and exit with
        the original state.

        name: str
             string identifier for the RNG state.

        Raises `Exception` if no state was added under `name`.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise Exception(f"cuda rng state {name} is not added")
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


def reduce_scatter_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[Any]]:
    """Reduce-scatter the input tensor across model parallel group.

    Arguments:
        input_: tensor whose first dimension is divisible by the group size.
        tp_group: process group to reduce-scatter across.
        async_op: forwarded to `torch.distributed.reduce_scatter_tensor`.

    Returns:
        `(output, handle)`. With a world size of 1 this is `(input_, None)`.
        Otherwise `output` is a new tensor on the current CUDA device whose
        first dimension is `input_.size(0) // world_size`, and `handle` is
        whatever `reduce_scatter_tensor` returned.

    Raises:
        AssertionError: if the first dimension is not divisible by the world size.
    """
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    assert (
        dim_size[0] % world_size == 0
    ), "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )
    return output, handle


def gather_along_first_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[Any]]:
    """Gather tensors and concatenate along the first dimension.

    Arguments:
        input_: this rank's tensor.
        tp_group: process group to gather across.
        async_op: forwarded to `torch.distributed.all_gather_into_tensor`.

    Returns:
        `(output, handle)`. With a world size of 1 this is `(input_, None)`.
        Otherwise `output` is a new tensor on the current CUDA device whose
        first dimension is `input_.size(0) * world_size`, and `handle` is
        whatever `all_gather_into_tensor` returned.
    """
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )

    return output, handle


def gather_along_last_dim(
    input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False
) -> Tuple[torch.Tensor, Optional[Any]]:
    """Gather tensors and concatenate along the last dimension.

    Arguments:
        input_: this rank's tensor.
        tp_group: process group to gather across.
        async_op: forwarded to `torch.distributed.all_gather_into_tensor`.

    Returns:
        `(output, handle)`. With a world size of 1 this is `(input_, None)`.
        Otherwise `output` is a new tensor on the current CUDA device whose
        last dimension is `input_.size(-1) * world_size`, and `handle` is
        whatever `all_gather_into_tensor` returned.

    NOTE(review): `torch.distributed.all_gather_into_tensor` is documented to lay
    the per-rank inputs out along the primary (first) dimension of the output.
    The buffer here is only *shaped* as a last-dim concatenation, so for inputs
    whose leading dimensions are not all 1 the element order presumably differs
    from `torch.cat(per_rank_inputs, dim=-1)` -- confirm against callers before
    relying on this for multi-row tensors.
    """
    world_size = get_distributed_world_size(tp_group)
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_, None

    dim_size = list(input_.size())
    dim_size[-1] = dim_size[-1] * world_size

    output = torch.empty(
        dim_size, dtype=input_.dtype, device=torch.cuda.current_device()
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=tp_group, async_op=async_op
    )

    return output, handle


def allreduce(
    input_: torch.Tensor,
    tp_group: Optional[dist_group_type] = None,
    async_op: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """All-reduce the input tensor across model parallel group.

    Arguments:
        input_: tensor handed to `torch.distributed.all_reduce`.
        tp_group: process group to reduce across.
        async_op: forwarded to `torch.distributed.all_reduce`.

    Returns:
        `(input_, handle)`, where `handle` is `None` when the world size is 1
        and otherwise whatever `torch.distributed.all_reduce` returned.
    """
    single_gpu = get_distributed_world_size(tp_group) == 1
    if single_gpu:
        # Nothing to reduce with a single participant.
        return input_, None

    work = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op)
    return input_, work