"""Utils for model executor."""
from typing import Any, Dict, Optional

import torch

from vllm.utils import seed_everything


def set_random_seed(seed: int) -> None:
    """Seed the random number generators used by the model executor.

    Thin wrapper that forwards ``seed`` to ``vllm.utils.seed_everything``;
    see that function for exactly which generators are seeded.

    Args:
        seed: The seed value to pass through.
    """
    seed_everything(seed)


def set_weight_attrs(
    weight: torch.Tensor,
    weight_attrs: Optional[Dict[str, Any]],
) -> None:
    """Set attributes on a weight tensor.

    This method is used to set attributes on a weight tensor. This method
    will not overwrite existing attributes: if any key in ``weight_attrs``
    already exists on ``weight``, nothing is set and an error is raised.

    Args:
        weight: The weight tensor.
        weight_attrs: A dictionary of attributes to set on the weight tensor.
            ``None`` is accepted and treated as "nothing to set".

    Raises:
        AssertionError: If any key in ``weight_attrs`` is already an
            attribute of ``weight``.
    """
    if weight_attrs is None:
        return
    # Validate every key before mutating anything, so a conflict leaves
    # ``weight`` untouched instead of half-updated.
    for key in weight_attrs:
        if hasattr(weight, key):
            # Raised explicitly rather than via ``assert`` so the guard still
            # holds under ``python -O``; AssertionError is kept so existing
            # callers that catch it keep working.
            raise AssertionError(
                f"Overwriting existing tensor attribute: {key}")
    for key, value in weight_attrs.items():
        setattr(weight, key, value)