# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import re
import warnings
from contextlib import contextmanager
from functools import partial
from typing import List

import torch
from torch import nn

from lmdeploy.lite.defaults import KV_CACHE_SIGNATURE, OFFLOAD_MOD


def extract_return_values(module: nn.Module) -> List[str]:
    """Extracts return values from the given module's forward method.

    Args:
        module (nn.Module): Module to inspect

    Returns:
        List[str]: List of return value names
    """
    last_line = inspect.getsource(module.forward).rstrip('\n').split('\n')[-1]
    pattern = r'return ([\w\s,]+)'
    match = re.search(pattern, last_line)
    if match:
        return_values = match.group(1).split(',')
        return [value.strip() for value in return_values]
    else:
        return []


def find_kv_cache_idx(module: nn.Module) -> int:
    """Finds the index of the kv cache signature in the module's forward
    parameters."""
    signatures = list(inspect.signature(module.forward).parameters.keys())
    if KV_CACHE_SIGNATURE not in signatures:
        raise ValueError(f'{KV_CACHE_SIGNATURE} not in signatures of '
                         f'{type(module)} forward.')
    return signatures.index(KV_CACHE_SIGNATURE)


def find_modules_by_return_value(model: nn.Module,
                                 value: str) -> List[nn.Module]:
    """Finds modules in the model whose forward method returns the given
    value.

    Args:
        model (nn.Module): Model to inspect
        value (str): Return value name to search for

    Returns:
        List[nn.Module]: List of matching modules

    Raises:
        ValueError: If no matching modules are found
    """
    modules = []
    for name, module in model.named_modules():
        returns = extract_return_values(module)
        if value in returns:
            print(f'Found {name} returning {value}')
            modules.append(module)

    if not modules:
        error_msg = f'No modules found returning {value}. '
        error_msg += 'Please check if the default KV_CACHE_SIGNATURE '
        error_msg += f"'{KV_CACHE_SIGNATURE}' matches what is used in your "
        error_msg += 'model code. If not, you can modify KV_CACHE_SIGNATURE '
        error_msg += 'in `lmdeploy.lite.defaults`.'
        raise ValueError(error_msg)
    return modules
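

# A minimal sketch of what the helpers above detect on a typical decoder
# layer. `DecoderLayer` is hypothetical, and this assumes KV_CACHE_SIGNATURE
# is 'past_key_value' (see `lmdeploy.lite.defaults` for the actual value):
#
#     class DecoderLayer(nn.Module):
#
#         def forward(self, hidden_states, attention_mask=None,
#                     past_key_value=None):
#             ...
#             return hidden_states, past_key_value
#
#     layer = DecoderLayer()
#     extract_return_values(layer)  # ['hidden_states', 'past_key_value']
#     find_kv_cache_idx(layer)      # 2 (bound methods exclude `self`)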


@contextmanager
def offload_kv_cache(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads the kv cache to the given device during the forward pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
    """
    modules = find_modules_by_return_value(model, KV_CACHE_SIGNATURE)

    original_forwards = {mod: mod.forward for mod in modules}
    input_idxs = {mod: find_kv_cache_idx(mod) for mod in modules}
    output_idxs = {
        mod: extract_return_values(mod).index(KV_CACHE_SIGNATURE)
        for mod in modules
    }

    def wrap_forward(module, *args, **kwargs):
        idx = input_idxs[module]
        if idx >= len(args):
            # kv cache passed as a keyword argument
            if KV_CACHE_SIGNATURE in kwargs:
                if kwargs[KV_CACHE_SIGNATURE] is not None:
                    # Assumes the incoming kv cache supports `.to()`
                    kwargs[KV_CACHE_SIGNATURE] = \
                        kwargs[KV_CACHE_SIGNATURE].to(device)
            else:
                raise ValueError(f'No kv cache input found at index {idx}')
        else:
            # kv cache passed positionally
            args = list(args)
            args[idx] = args[idx].to(device)
            args = tuple(args)

        result = original_forwards[module](*args, **kwargs)

        result = list(result)
        idx = output_idxs[module]
        # Move the returned kv cache back to CPU to free GPU memory
        key = result[idx][0].to('cpu')
        value = result[idx][1].to('cpu')
        torch.cuda.empty_cache()
        result[idx] = (key, value)
        return tuple(result)

    try:
        for module in modules:
            module.forward = partial(wrap_forward, module)
        yield
    finally:
        for module in modules:
            module.forward = original_forwards[module]


@contextmanager
def offload_weights(model: nn.Module, device: str = 'cuda') -> None:
    """Offloads specified modules to the given device during the forward
    pass.

    Args:
        model (nn.Module): Model for inference
        device (str): Device to offload to

    Yields:
        None
    """
    target_modules = OFFLOAD_MOD

    def before_forward(module: nn.Module, inp: torch.Tensor):
        # Move the module to the target device right before its forward pass
        module.to(device)

    def after_forward(module: nn.Module, inp: torch.Tensor,
                      out: torch.Tensor):
        # Move the module back to CPU once its forward pass finishes
        module.to('cpu')
        torch.cuda.empty_cache()

    def _to_device(m, spec_modules, dev):
        # Move everything to `dev` except modules of the offloaded types,
        # which stay on CPU until their forward pre-hook moves them
        if len(spec_modules) == 0 or len(list(m.children())) == 0:
            m.to(dev)
            return
        for child in m.children():
            if isinstance(child, spec_modules):
                child.to('cpu')
            else:
                _to_device(child, spec_modules, dev)

    warnings.warn('By default, offloading is applied to `nn.Linear`. '
                  'Modules that should be offloaded can be added to '
                  '`lmdeploy.lite.defaults.OFFLOAD_MOD`.')
    _to_device(model, target_modules, device)

    handles = []
    for module in model.modules():
        if isinstance(module, target_modules):
            handle1 = module.register_forward_pre_hook(before_forward)
            handle2 = module.register_forward_hook(after_forward)
            handles.extend([handle1, handle2])

    try:
        yield
    finally:
        for handle in handles:
            handle.remove()
        model.to('cpu')
        torch.cuda.empty_cache()


@contextmanager
def memory_efficient_inference(model: nn.Module,
                               offload: bool = True,
                               device: str = 'cuda') -> None:
    """Memory efficient inference context manager.

    Moves the model to the device for inference, optionally offloading
    specific modules so that they occupy GPU memory only during their own
    forward pass.

    Args:
        model (nn.Module): Model for inference
        offload (bool): Whether to offload modules
        device (str): Device for inference

    Yields:
        None
    """
    if offload:
        warnings.warn('Using offload mode - modules defined in OFFLOAD_MOD '
                      'will be moved to the GPU only during their forward '
                      'pass.')
        warnings.warn('Offload mode incurs a performance penalty due to '
                      'frequent CPU-GPU data transfers.')
        with torch.inference_mode():
            with offload_kv_cache(model, device):
                with offload_weights(model, device):
                    yield
    else:
        model.to(device)
        with torch.inference_mode():
            yield
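

# Example usage, as a minimal sketch. `AutoModelForCausalLM`, the checkpoint
# name and `inputs` below are placeholders rather than part of this module:
#
#     from transformers import AutoModelForCausalLM
#
#     model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
#     with memory_efficient_inference(model, offload=True, device='cuda'):
#         outputs = model(**inputs)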