utils.py 1.48 KB
Newer Older
1
"""Utils for model executor."""
2
import random
3
import importlib
4
from typing import Any, Dict, Optional
5
6
7
8

import numpy as np
import torch

9
10
11
12
13
14
15
from vllm.config import DeviceConfig, ModelConfig

# Maps a device type (looked up with ``device_config.device_type`` in
# ``get_model``) to the name of the module under ``vllm.model_executor`` whose
# ``get_model`` function is used to load the model for that device.
DEVICE_TO_MODEL_LOADER_MAP = {
    "cuda": "model_loader",
    "neuron": "neuron_model_loader",
}

16
17
18
19
20
21
22

def set_random_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: The seed value handed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # The CUDA generators are only seeded when CUDA reports as available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
) -> None:
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
            ``None`` is accepted and leaves the tensor untouched.

    Raises:
        AssertionError: If ``weight`` already has an attribute named by one
            of the keys. Attributes set before the offending key remain set.
    """
    if weight_attrs is None:
        return
    for key, value in weight_attrs.items():
        # Raise explicitly instead of using an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would silently allow an
        # existing attribute to be overwritten.
        if hasattr(weight, key):
            raise AssertionError(
                f"Overwriting existing tensor attribute: {key}")
        setattr(weight, key, value)
44
45
46
47
48
49
50
51
52


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> torch.nn.Module:
    """Load a model through the loader module selected by the device type.

    The loader module name is looked up in ``DEVICE_TO_MODEL_LOADER_MAP``
    using ``device_config.device_type``. That module is imported from
    ``vllm.model_executor`` and its ``get_model`` function is called.

    Args:
        model_config: Model configuration, forwarded to the loader.
        device_config: Device configuration; its ``device_type`` selects the
            loader module and it is also forwarded to the loader.
        **kwargs: Extra keyword arguments forwarded unchanged to the loader.

    Returns:
        The value returned by the selected loader's ``get_model``.

    Raises:
        KeyError: If ``device_config.device_type`` has no entry in
            ``DEVICE_TO_MODEL_LOADER_MAP``.
    """
    loader_name = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
    # Import on demand so only the selected device's loader module is loaded.
    loader = importlib.import_module(f"vllm.model_executor.{loader_name}")
    return loader.get_model(model_config, device_config, **kwargs)