"""Utils for model executor.""" from typing import Any, Dict, Optional import torch from vllm.platforms import current_platform from vllm.utils import seed_everything def set_random_seed(seed: int) -> None: seed_everything(seed) def set_weight_attrs( weight: torch.Tensor, weight_attrs: Optional[Dict[str, Any]], ): """Set attributes on a weight tensor. This method is used to set attributes on a weight tensor. This method will not overwrite existing attributes. Args: weight: The weight tensor. weight_attrs: A dictionary of attributes to set on the weight tensor. """ if weight_attrs is None: return for key, value in weight_attrs.items(): assert not hasattr( weight, key), (f"Overwriting existing tensor attribute: {key}") # NOTE(woosuk): During weight loading, we often do something like: # narrowed_tensor = param.data.narrow(0, offset, len) # narrowed_tensor.copy_(real_weight) # expecting narrowed_tensor and param.data to share the same storage. # However, on TPUs, narrowed_tensor will lazily propagate to the base # tensor, which is param.data, leading to the redundant memory usage. # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. # TODO(woosuk): Remove this hack once we have a better solution. if current_platform.is_tpu() and key == "weight_loader": value = _make_synced_weight_loader(value) setattr(weight, key, value) def _make_synced_weight_loader(original_weight_loader): def _synced_weight_loader(param, *args, **kwargs): original_weight_loader(param, *args, **kwargs) torch._sync(param) return _synced_weight_loader def pad_weight(weight: torch.Tensor, num_pad: int, pad_dim: int = 0): if weight.dim() == 1: padding = torch.zeros(num_pad, dtype=weight.dtype, device=weight.device) padded_weight = torch.cat([weight, padding], dim=0) elif weight.dim() == 2: if pad_dim == 0: padding = torch.zeros(num_pad, weight.shape[1], dtype=weight.dtype, device=weight.device) padded_weight = torch.cat([weight, padding], dim=0) elif pad_dim == 1: padding = torch.zeros(weight.shape[0], num_pad, dtype=weight.dtype, device=weight.device) padded_weight = torch.cat([weight, padding], dim=1) else: raise ValueError("pad_dim must be 0 or 1") else: raise ValueError("Weight tensor must be 1D or 2D") padded_weight = padded_weight.contiguous() return padded_weight def gemm_bank_conf(weight): is_mul_of_2048 = weight % 2048 == 0 is_power_of_two = (weight & (weight - 1)) == 0 and weight != 0 if is_mul_of_2048 and is_power_of_two: return True else: return False