# model_config.py — last committed by Lianmin Zheng (per repository blame view).
import os
from typing import Optional, Union

import torch
from sglang.srt.hf_transformers_utils import get_config, get_context_length


class ModelConfig:
    """Normalized view of a Hugging Face model config.

    Loads the HF config with ``get_config`` and copies a fixed set of
    architecture values onto this object under uniform attribute names. Keys
    that not every model config defines get a fallback.
    """

    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
    ) -> None:
        """Load the config at ``path`` and populate the unified attributes.

        Args:
            path: Model path or identifier, passed through to ``get_config``.
            trust_remote_code: Passed through to ``get_config``.
            revision: Optional revision, passed through to ``get_config``.
        """
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.hf_config = get_config(self.path, trust_remote_code, revision)

        # Unify the config keys for hf_config
        # Value computed by get_context_length; presumably the maximum
        # sequence length the model supports — confirm in hf_transformers_utils.
        self.context_len = get_context_length(self.hf_config)
        self.head_dim = self._get_head_dim(self.hf_config)
        self.num_attention_heads = self.hf_config.num_attention_heads
        self.num_key_value_heads = self._get_num_key_value_heads(self.hf_config)
        self.hidden_size = self.hf_config.hidden_size
        self.num_hidden_layers = self.hf_config.num_hidden_layers
        self.vocab_size = self.hf_config.vocab_size

    @staticmethod
    def _get_head_dim(hf_config) -> int:
        """Return the per-head dimension declared by, or derived from, the config.

        An explicit, non-None ``head_dim`` wins. Some architectures (e.g. Gemma)
        use a head size that is not ``hidden_size // num_attention_heads``.
        Otherwise the conventional derivation is used.
        """
        head_dim = getattr(hf_config, "head_dim", None)
        if head_dim is None:
            head_dim = hf_config.hidden_size // hf_config.num_attention_heads
        return head_dim

    @staticmethod
    def _get_num_key_value_heads(hf_config) -> int:
        """Return the number of key/value heads, defaulting to the attention-head count.

        Configs without ``num_key_value_heads`` (or with it set to None) are
        treated as having one KV head per attention head.
        """
        num_kv_heads = getattr(hf_config, "num_key_value_heads", None)
        if num_kv_heads is None:
            num_kv_heads = hf_config.num_attention_heads
        return num_kv_heads