# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

14
15
from megatron.memory import allocate_mem_buff

16
from .initialize import get_data_parallel_rank
17
18
19
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
20
21
22


# Default name for the model parallel rng tracker.
23
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
24
25
26
27
28
29
30
31
32
33
34


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Argumentss:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)
55
56
57
58

    _lazy_call(cb)


59
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
    """Break a tensor into equal 1D chunks across the tensor-model-parallel group.

    The tensor is flattened with ``view(-1)`` and divided into
    ``world_size`` contiguous partitions of ``numel // world_size``
    elements each; the partition belonging to the current
    tensor-model-parallel rank is returned.

    Arguments:
        tensor (torch.Tensor): Tensor to split. It is flattened with
            ``view(-1)``, so its layout must allow a flat view.
        new_buffer (bool): If True, copy the chunk into a freshly allocated
            buffer on the current CUDA device; otherwise return a view into
            ``tensor``'s storage.

    Returns:
        torch.Tensor: 1D chunk for the current tensor-model-parallel rank.

    NOTE: if ``numel`` is not divisible by the world size, the trailing
    ``numel % world_size`` elements belong to no chunk.
    """
    world_size = get_tensor_model_parallel_world_size()
    partition_size = torch.numel(tensor) // world_size
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    chunk = tensor.view(-1)[start_index:end_index]
    if not new_buffer:
        return chunk
    data = torch.empty(partition_size, dtype=tensor.dtype,
                       device=torch.cuda.current_device(),
                       requires_grad=False)
    data.copy_(chunk)
    return data


def gather_split_1d_tensor(tensor):
    """Opposite of ``split_tensor_into_1d_equal_chunks``: gather the 1D
    chunks held by every tensor-model-parallel rank into one flat tensor.

    Arguments:
        tensor (torch.Tensor): This rank's 1D chunk.

    Returns:
        torch.Tensor: Flat tensor of ``numel(tensor) * world_size`` elements,
        allocated on the current CUDA device.
    """
    numel_gathered = torch.numel(tensor) * \
        get_tensor_model_parallel_world_size()
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # Gather straight into the flat output buffer: this calls the NCCL
    # all-gather directly, whereas torch.distributed.all_gather (list of
    # tensors) does internal copies and can slow things down.
    # `all_gather_into_tensor` is the public API for this collective in newer
    # PyTorch; the private, experimental `_all_gather_base` is only used as a
    # fallback for older releases that lack it.
    group = get_tensor_model_parallel_group()
    all_gather_into_tensor = getattr(
        torch.distributed, 'all_gather_into_tensor', None)
    if all_gather_into_tensor is not None:
        all_gather_into_tensor(gathered, tensor, group=group)
    else:
        torch.distributed._all_gather_base(gathered, tensor, group=group)
    return gathered


def _kernel_make_viewless_tensor(inp, requires_grad):
Lawrence McAfee's avatar
Lawrence McAfee committed
93
94
95
96
97
98
99
100
    '''Make a viewless tensor.

    View tensors have the undesirable side-affect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
101
102
103
104
105
106
107
108
109
    out = torch.empty(
        (1,),
        dtype = inp.dtype,
        device = inp.device,
        requires_grad = requires_grad,
    )
    out.data = inp.data
    return out

Lawrence McAfee's avatar
Lawrence McAfee committed
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd wrapper around '_kernel_make_viewless_tensor'.

    Use it when the result must stay connected to the computation graph
    while still not being a view (e.g., ParallelTransformer's
    hidden_states). It is selected by passing 'keep_graph = True' to
    'make_viewless_tensor()'.
    '''

    @staticmethod
    def forward(ctx, inp, requires_grad):
        viewless = _kernel_make_viewless_tensor(inp, requires_grad)
        return viewless

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient flows to 'inp' unchanged; the 'requires_grad' flag input
        # gets no gradient.
        grad_inp, grad_flag = grad_output, None
        return grad_inp, grad_flag

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.

    Arguments:
        inp (torch.Tensor): Tensor to convert.
        requires_grad (bool): 'requires_grad' flag for the new tensor (only
            used when 'inp' is a view).
        keep_graph (bool): If True, go through 'MakeViewlessTensor' so the
            result stays in the autograd graph; otherwise call
            '_kernel_make_viewless_tensor' directly.

    Returns:
        'inp' itself when it is not a view; otherwise a new viewless tensor
        sharing its data.
    '''
    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    return _kernel_make_viewless_tensor(inp, requires_grad)


def assert_viewless_tensor(tensor, extra_msg=None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).

    Arguments:
        tensor: A torch.Tensor, a list (each element is checked), or any
            other object (returned unchecked).
        extra_msg: Extra context appended to the error message.

    Returns:
        The input 'tensor', unchanged.

    Raises:
        AssertionError: if 'tensor' (or any list element) is a view. It is
            raised explicitly rather than via an 'assert' statement so that
            the check is not stripped when Python runs with '-O'.
    '''
    if isinstance(tensor, list):
        # Check each element, forwarding the caller's context message.
        for item in tensor:
            assert_viewless_tensor(item, extra_msg=extra_msg)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    if tensor._base is not None:
        raise AssertionError((
            "Ensure tensor._base is None before setting tensor.data or storing "
            "tensor to memory buffer. Otherwise, a memory leak will occur (and "
            "likely accumulate over iterations). %s"
        ) % extra_msg)
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception (via 'assert_viewless_tensor').

    Arguments:
        tensor (torch.Tensor): Tensor whose '.data' is replaced.
        new_data_tensor (torch.Tensor): New data for 'tensor'.
    '''
    # Build the diagnostic string only when the check is going to fail;
    # CheckpointFunction calls this in forward and backward, so the common
    # path skips the formatting work.
    if tensor._base is None:
        extra_msg = None
    else:
        extra_msg = (
            "FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
            % (tensor._base.shape, new_data_tensor.shape))
    assert_viewless_tensor(tensor, extra_msg=extra_msg)
    tensor.data = new_data_tensor


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Return a shallow copy of the name -> rng state mapping.

        The dictionary is copied, so later `add`/`fork` calls on this
        tracker do not rebind entries in the returned snapshot; the state
        tensors themselves are shared, not cloned."""
        return dict(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility.

        NOTE: `states` is adopted as-is (not copied), so later `add`/`fork`
        calls also update the caller's dictionary."""
        self.states_ = states

    def add(self, name, seed):
        """Seed a new cuda rng state and track it under `name`.

        The current cuda rng state is restored afterwards.

        Raises:
            Exception: if `seed` or `name` is already registered. Both are
                checked before anything is recorded, so a rejected call
                leaves the tracker unchanged.
        """
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        self.seeds_.add(seed)
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state.

        On exit (normal or exceptional), the advanced state is stored back
        under `name` and the previous cuda rng state is restored."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# Module-level tracker instance, shared by every caller of
# get_cuda_rng_tracker().
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Return the module-level cuda rng states tracker."""
    tracker = _CUDA_RNG_STATE_TRACKER
    return tracker


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is replacement for that
    function. Any states previously held by the rng tracker are
    discarded.

    Two sets of RNG states are tracked:
        default state: the default CUDA generator, seeded with `seed`.
            It is meant for data parallelism, e.g. dropout in the
            non-tensor-model-parallel regions. It is identical on all
            tensor-model-parallel ranks that pass the same `seed`; any
            variation across model parallel groups must come from the
            caller's choice of `seed`.
        tensor-model-parallel state: registered in the rng tracker under
            `_MODEL_PARALLEL_RNG_TRACKER_NAME` and seeded with
            `seed + 2718 + tensor_model_parallel_rank`, so it differs among
            the GPUs of a tensor-model-parallel group but matches across
            data parallel groups (for equal `seed`). It is meant for e.g.
            dropout in model parallel regions.

    Arguments:
        seed (int): Base seed.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if torch.distributed.get_rank() == 0:
        print('> initializing model parallel cuda seeds on global rank {}, '
              'model parallel rank {}, and data parallel rank {} with '
              'model parallel seed: {} and data parallel seed: {}'.format(
                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
                  get_data_parallel_rank(), tensor_model_parallel_seed,
                  data_parallel_seed), flush=True)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.

    When `distribute_saved_activations` is True, the first input is split
    across the tensor-model-parallel group while it is saved for backward,
    and gathered back before the forward pass is recomputed.
    """
    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        ctx.run_function = run_function
        ctx.distribute_saved_activations \
            = distribute_saved_activations

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0],
                split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            safely_set_viewless_tensor_data(
                inputs[0],
                gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        try:
            # Recompute the forward pass, this time recording the graph.
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)
        finally:
            # Set the states back to what they were at the start of this
            # function, even if the recomputation raised.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # Back-propagate only through outputs that are tensors attached to
        # the graph (as torch.utils.checkpoint does); non-tensor or
        # non-differentiable outputs would make autograd.backward fail.
        outputs_with_grad = []
        args_with_grad = []
        for output, grad in zip(outputs, args):
            if torch.is_tensor(output) and output.requires_grad:
                outputs_with_grad.append(output)
                args_with_grad.append(grad)
        if not outputs_with_grad:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        # No gradients for `run_function` and `distribute_saved_activations`.
        return (None, None) + grads


def checkpoint(function, distribute_saved_activations, *args):
    """Checkpoint a model or part of the model.

    Adapted from torch.utils.checkpoint: `function(*args)` is first run
    under `torch.no_grad()` and recomputed during backward. The CPU, CUDA
    and tracked model-parallel RNG states are restored for the
    recomputation.

    Arguments:
        function: Callable to checkpoint.
        distribute_saved_activations (bool): If True, the first element of
            `args` is split across the tensor-model-parallel group while it
            is saved for backward.
        *args: Inputs passed to `function`.

    Returns:
        The outputs of `function(*args)`, produced through
        `CheckpointFunction.apply`.
    """
    return CheckpointFunction.apply(function,
                                    distribute_saved_activations, *args)