# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import re
import warnings
from contextlib import contextmanager
from functools import partial
from typing import List

import torch
from torch import nn

from lmdeploy.lite.defaults import KV_CACHE_SIGNATURE, OFFLOAD_MOD


def extract_return_values(module: nn.Module) -> List[str]:
    """Extracts return values from given module's forward method.

    Args:
        module (nn.Module): Module to inspect

    Returns:
        list[str]: List of return values
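
    Example:
        An illustrative sketch, assuming ``decoder_layer`` is a module whose
        forward ends with ``return hidden_states, past_key_value``::

            extract_return_values(decoder_layer)
            # -> ['hidden_states', 'past_key_value']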
    """

    # Only the last line of the forward source is inspected, so this assumes
    # the method ends with a plain ``return a, b, ...`` statement.
    last_line = inspect.getsource(module.forward).rstrip('\n').split('\n')[-1]
    pattern = r'return ([\w\s,]+)'
    match = re.search(pattern, last_line)

    if match:
        return_values = match.group(1).split(',')
        return [value.strip() for value in return_values]
    else:
        return []


def find_kv_cache_idx(module: nn.Module) -> int:
    """Finds index of kv cache signature in module's forward parameters."""

    signatures = list(inspect.signature(module.forward).parameters.keys())
    if KV_CACHE_SIGNATURE not in signatures:
        raise ValueError(f'{KV_CACHE_SIGNATURE} not in signatures of '
                         f'{type(module)} forward.')
    return signatures.index(KV_CACHE_SIGNATURE)


def find_modules_by_return_value(model: nn.Module,
                                 value: str) -> List[nn.Module]:
    """Finds modules in model that return given value.

    Args:
        model (nn.Module): Model to inspect
        value (str): Return value to search for

    Returns:
        list[nn.Module]: List of matching modules

    Raises:
        ValueError: If no matching modules found
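
    Example:
        An illustrative sketch, assuming a Hugging Face-style model whose
        decoder layers end their forward with
        ``return hidden_states, past_key_value``::

            layers = find_modules_by_return_value(model, 'past_key_value')
            # one entry per decoder layer that returns the kv cache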
    """

    modules = []
    for name, module in model.named_modules():
        returns = extract_return_values(module)
        if value in returns:
            print(f'Found {name} returning {value}')
            modules.append(module)

    if not modules:
        error_msg = f'No modules found returning {value}. '
        error_msg += 'Please check if the default KV_CACHE_SIGNATURE '
        error_msg += f"'{KV_CACHE_SIGNATURE}' matches what is used in your "
        error_msg += 'model code. If not, you can modify KV_CACHE_SIGNATURE '
        error_msg += 'in `lmdeploy.lite.defaults`.'
        raise ValueError(error_msg)

    return modules


@contextmanager
def offload_kv_cache(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads kv cache to given device during forward pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
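
    Example:
        A minimal sketch, assuming a Hugging Face-style decoder-only model
        already placed on the device, whose layers pass the kv cache under
        the name stored in ``KV_CACHE_SIGNATURE``::

            with offload_kv_cache(model, device='cuda'):
                outputs = model(input_ids, use_cache=True)
            # the per-layer key/value tensors returned above now live on CPU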
    """

    modules = find_modules_by_return_value(model, KV_CACHE_SIGNATURE)

    # Record each module's original forward and where the kv cache sits in
    # its input signature and in its return values.
    original_forwards = {mod: mod.forward for mod in modules}
    input_idxs = {mod: find_kv_cache_idx(mod) for mod in modules}
    output_idxs = {
        mod: extract_return_values(mod).index(KV_CACHE_SIGNATURE)
        for mod in modules
    }

    def wrap_forward(module, *args, **kwargs):

        def to_device(kv):
            # The kv cache may be None (prefill step), a single tensor or a
            # (key, value) tuple; move whatever tensors are present.
            if kv is None:
                return kv
            if isinstance(kv, (tuple, list)):
                return type(kv)(t.to(device) for t in kv)
            return kv.to(device)

        idx = input_idxs[module]
        if idx >= len(args):
            # kv cache passed as a keyword argument
            if KV_CACHE_SIGNATURE in kwargs:
                kwargs[KV_CACHE_SIGNATURE] = to_device(
                    kwargs[KV_CACHE_SIGNATURE])
            else:
                raise ValueError(f'No kv cache input found at index {idx}')
        else:
            # kv cache passed positionally
            args = list(args)
            args[idx] = to_device(args[idx])
            args = tuple(args)

        result = original_forwards[module](*args, **kwargs)

        result = list(result)
        idx = output_idxs[module]

        # Move kv cache outputs back to CPU
        key = result[idx][0].to('cpu')
        value = result[idx][1].to('cpu')
        torch.cuda.empty_cache()

        result[idx] = (key, value)
        result = tuple(result)

        return result

    try:
        for module in modules:
            module.forward = partial(wrap_forward, module)

        yield

    finally:
        for module in modules:
            module.forward = original_forwards[module]
            del original_forwards[module]


@contextmanager
def offload_weights(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads specified modules to given device during forward pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
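
    Example:
        A minimal sketch, assuming the default ``OFFLOAD_MOD`` of
        ``nn.Linear`` and inputs already placed on the device::

            with offload_weights(model, device='cuda'):
                outputs = model(input_ids)
            # each linear layer visits the GPU only for its own forward and
            # is returned to CPU right afterwards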
    """

    # Module types that should stay on CPU and only visit the device while
    # their forward runs.
    target_modules = OFFLOAD_MOD

    def before_forward(module: nn.Module, inp: torch.Tensor):
        # Stage the offloaded module onto the target device just before its
        # forward runs.
        module.to(device)

    def after_forward(module: nn.Module, inp: torch.Tensor, out: torch.Tensor):
        # Move the module back to CPU right after its forward finishes.
        module.to('cpu')
        torch.cuda.empty_cache()

    def _to_device(m, spec_modules, dev):
        # A module with no children (or an empty offload spec) is moved to
        # the target device wholesale.
        if len(spec_modules) == 0 or len(list(m.children())) == 0:
            m.to(dev)
            return

        for child in m.children():
            if isinstance(child, spec_modules):
                # Offloaded module types stay on CPU; the hooks move them to
                # the device only while their forward runs.
                child.to('cpu')
            else:
                _to_device(child, spec_modules, dev)

    warnings.warn('By default, offloading will be done on `nn.Linear`. '
                  'You can add the module types you want to offload to '
                  '`lmdeploy.lite.defaults.OFFLOAD_MOD`.')

    _to_device(model, target_modules, device)

    handles = []
    for module in model.modules():
        if isinstance(module, target_modules):
            handle1 = module.register_forward_pre_hook(before_forward)
            handle2 = module.register_forward_hook(after_forward)
            handles.extend([handle1, handle2])

    try:
        yield
    finally:
        for handle in handles:
            handle.remove()

        model.to('cpu')
        torch.cuda.empty_cache()


@contextmanager
def memory_efficient_inference(model: nn.Module,
                               offload: bool = True,
                               device: str = 'cuda') -> None:
    """Memory efficient inference context manager.

    Moves the model to the given device for inference. With ``offload=True``,
    the kv cache and the module types in ``OFFLOAD_MOD`` are kept on CPU and
    moved to the device only while they are needed.

    Args:
        model (nn.Module): Model for inference
        offload (bool): Whether to offload modules
        device (str): Device for inference

    Yields:
        None
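
    Example:
        A minimal sketch, assuming a Hugging Face-style causal LM that is too
        large to keep on the GPU as a whole::

            with memory_efficient_inference(model, offload=True,
                                            device='cuda'):
                outputs = model(input_ids.to('cuda'))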
    """

    if offload:
        warnings.warn('Using offload mode - modules defined in OFFLOAD_MOD '
                      'will be moved to GPU only during their forward pass.')
        warnings.warn(
            'Using offload mode will incur a performance penalty due to '
            'frequent CPU-GPU data transfers.')
        with torch.inference_mode():
            with offload_kv_cache(model, device):
                with offload_weights(model, device):
                    yield
    else:
        model.to(device)
        with torch.inference_mode():
            yield