# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions used throughout Megatron core"""
import math
import operator
from functools import reduce

import torch

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedTensor


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
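
# Illustrative usage sketch (hypothetical values, not part of the upstream module):
# `divide` is typically used to split a dimension evenly across a parallel group,
# e.g. attention heads across tensor-parallel ranks.
#
#     num_attention_heads = 16
#     tensor_parallel_size = 4
#     heads_per_rank = divide(num_attention_heads, tensor_parallel_size)  # -> 4
#     divide(num_attention_heads, 3)  # raises AssertionError ("16 is not divisible by 3")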


def get_attr_wrapped_model(model, attr, allow_none=True):
    """Get an attribute from a wrapped model"""
    if isinstance(model, list):
        raise RuntimeError("_get_attr_wrapped_model given a list of models")

    if allow_none:

        def condition(model, attr):
            return not hasattr(model, attr)

    else:

        def condition(model, attr):
            return getattr(model, attr, None) is None

    while condition(model, attr):
        if not hasattr(model, "module"):
            raise RuntimeError(f"_get_attr_wrapped_model couldn't find attribute {attr}")

        model = model.module
    return getattr(model, attr)
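
# Illustrative usage sketch (assumed wrapper, not part of the upstream module): wrappers
# such as torch.nn.parallel.DistributedDataParallel expose the inner model via `.module`,
# so the lookup walks the `.module` chain until `attr` is found (or is non-None when
# allow_none=False).
#
#     ddp_model = torch.nn.parallel.DistributedDataParallel(gpt_model)  # gpt_model is hypothetical
#     config = get_attr_wrapped_model(ddp_model, 'config', allow_none=False)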


def get_model_type(model):
    return get_attr_wrapped_model(model, 'model_type')


def get_model_config(model):
    return get_attr_wrapped_model(model, 'config', allow_none=False)


class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.
    Caller should ensure that buffers of the same name
    are not used concurrently."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        required_len = reduce(operator.mul, tensor_shape, 1)
        if (
            self.buffer.get((name, dtype), None) is None
            or self.buffer[(name, dtype)].numel() < required_len
        ):
            self.buffer[(name, dtype)] = torch.empty(
                required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
            )

        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
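
# Illustrative usage sketch (hypothetical shape and name, not part of the upstream module):
# the buffer keyed by (name, dtype) is reused across calls, so callers must not use two
# tensors obtained under the same name at the same time.
#
#     buffer = GlobalMemoryBuffer()
#     scores = buffer.get_tensor((8, 1024, 1024), torch.float16, "attn_scores")
#     # ... fill and consume `scores` before requesting "attn_scores" again ...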


def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without keeping a reference to the viewed tensor in the '._base'
    field.
    '''
    out = torch.empty((1,), dtype=inp.dtype, device=inp.device, requires_grad=requires_grad)
    out.data = inp.data
    return out


class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''

    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''

    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)
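
# Illustrative usage sketch (hypothetical tensors, not part of the upstream module):
# a transposed/reshaped tensor keeps a `._base` reference to the original tensor;
# making it viewless drops that reference while still pointing at the same data.
#
#     hidden = torch.randn(4, 8, requires_grad=True)
#     view = hidden.transpose(0, 1)  # view._base is hidden
#     out = make_viewless_tensor(view, requires_grad=True, keep_graph=True)
#     assert out._base is None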


def assert_viewless_tensor(tensor, extra_msg=None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        for t in tensor:
            assert_viewless_tensor(t)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor


def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(
        tensor,
        extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has shape %s."
        % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape),
    )
    tensor.data = new_data_tensor
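
# Illustrative usage sketch (hypothetical tensors, not part of the upstream module):
#
#     hidden = make_viewless_tensor(hidden_states, requires_grad=True, keep_graph=True)
#     safely_set_viewless_tensor_data(hidden, torch.zeros_like(hidden))
#     # had `hidden._base` been set, assert_viewless_tensor would raise an AssertionError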


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
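
# Illustrative usage sketch (hypothetical sizes, not part of the upstream module):
# `init_method_normal` is typically used for input projections, while
# `scaled_init_method_normal` is used for output projections, scaling the std by
# 1/sqrt(2 * num_layers) to keep residual-branch variance bounded with depth.
#
#     init_fn = scaled_init_method_normal(sigma=0.02, num_layers=24)
#     weight = torch.empty(1024, 4096)
#     init_fn(weight)  # in-place normal init with std = 0.02 / sqrt(48)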


def make_tp_sharded_tensor_for_checkpoint(tensor, key, tp_axis=0, replica_id=None, **kwargs):
    """ Helper for instantiating a ShardedTensor where the `tp_axis` dimension is sharded across TP group. """

    return ShardedTensor.from_rank_offsets(
        key,
        tensor,
        (
            tp_axis,
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_tensor_model_parallel_world_size(),
        ),
        replica_id=parallel_state.get_data_parallel_rank() if replica_id is None else replica_id,
        **kwargs,
    )
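
# Illustrative usage sketch (assumes megatron.core.parallel_state has been initialized;
# the layer and key names are hypothetical, not part of the upstream module):
#
#     sharded_weight = make_tp_sharded_tensor_for_checkpoint(
#         linear_fc1.weight, 'decoder.linear_fc1.weight', tp_axis=0,
#     )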


def make_sharded_tensor_for_checkpoint(tensor, key, **kwargs):
    """ Helper for instantiating a non-sharded ShardedTensor (replicated across TP and DP group). """

    return ShardedTensor.from_rank_offsets(
        key,
        tensor,
        replica_id=parallel_state.get_data_parallel_rank()
        * parallel_state.get_data_parallel_world_size()
        + parallel_state.get_tensor_model_parallel_rank(),
        **kwargs,
    )