model_info.py 8.1 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from typing import Optional, Union

from huggingface_hub import model_info
8
from pydantic import BaseModel
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from transformers import AutoConfig

# Bytes needed to store one parameter of each dtype, keyed by the dtype names
# used in safetensors metadata. Used to estimate the weight size from per-dtype
# parameter counts; dtypes missing from this map are priced by the caller's
# default, so every fixed-width safetensors dtype is listed explicitly.
DTYPE_BYTES_MAP = {
    "F32": 4,  # FP32: 4 bytes per parameter
    "BF16": 2,  # BF16: 2 bytes per parameter
    "F16": 2,  # FP16: 2 bytes per parameter
    "F8_E4M3": 1,  # FP8: 1 byte per parameter
    "F8_E5M2": 1,  # FP8: 1 byte per parameter
    "F8_E8M0": 1,  # FP8: 1 byte per parameter
    "I8": 1,  # INT8: 1 byte per parameter
    "I4": 0.5,  # INT4: 0.5 bytes per parameter
    "F64": 8,  # FP64: 8 bytes per parameter
    "I64": 8,  # INT64: 8 bytes per parameter
    "U64": 8,  # UINT64: 8 bytes per parameter
    "I32": 4,  # INT32: 4 bytes per parameter
    "U32": 4,  # UINT32: 4 bytes per parameter
    "I16": 2,  # INT16: 2 bytes per parameter
    "U16": 2,  # UINT16: 2 bytes per parameter
    "U8": 1,  # UINT8: 1 byte per parameter
    "BOOL": 1,  # BOOL: stored as 1 byte per element
}

# Config attribute names that may hold a model's maximum context length.
# get_model_info() probes them in this order and uses the first one whose
# value is not None, so earlier entries take precedence over later ones.
CONTEXT_LENGTH_ATTRS = [
    "max_position_embeddings",  # Most common (BERT, GPT, LLaMA, etc.)
    "n_positions",  # GPT-2, GPT-Neo
    "max_sequence_length",  # Some models
    "seq_length",  # Some older models
    "model_max_length",  # Some tokenizer configs
    "sliding_window",  # Mistral with sliding window attention
]

# Architecture names that get_model_info() reports with is_moe=True.
# Only MLA + MoE models are listed here; every other MoE model is
# deliberately treated as a dense model.
MOE_ARCHITECTURES = {"DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"}


def get_local_model_weight_size(
    model_path: Union[str, Path],
) -> float:
    """Return the combined size, in MB, of the weight files under a directory.

    The directory is walked recursively and every regular file whose path
    ends with ``.safetensors``, ``.bin``, ``.pt`` or ``.pth`` is counted.

    Args:
        model_path: Directory containing the model files.

    Returns:
        Total size of the matching files in MB (bytes / 1024**2).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If ``model_path`` exists but is not a directory.
    """
    root = Path(model_path)

    if not root.exists():
        raise FileNotFoundError(f"Model path does not exist: {root}")
    if not root.is_dir():
        raise ValueError(f"Model path is not a directory: {root}")

    # str.endswith accepts a tuple, so one call covers every extension.
    weight_suffixes = (".safetensors", ".bin", ".pt", ".pth")
    byte_count = sum(
        entry.stat().st_size
        for entry in root.rglob("*")
        if entry.is_file() and str(entry).endswith(weight_suffixes)
    )

    return byte_count / (1024**2)


def get_model_weight_size_from_hub(
    model_name: str,
    token: Optional[str] = None,
) -> float:
    """Return the model weight size in MB using Hugging Face Hub metadata.

    The size is taken from the repository's file listing when file sizes are
    present; otherwise it is estimated from the safetensors per-dtype
    parameter counts.

    Args:
        model_name: Hub repository id of the model.
        token: Optional access token forwarded to ``model_info``.

    Returns:
        Weight size in MB (bytes / 1024**2). This is ``0.0`` when neither
        file sizes nor safetensors metadata are available.

    Raises:
        RuntimeError: If querying or reading the Hub metadata fails for any
            reason; the original exception is attached as ``__cause__``.
    """
    weight_extensions = (".safetensors", ".bin", ".pt", ".pth")
    try:
        info = model_info(model_name, token=token)

        # Sum the sizes of weight files, skipping entries with no size.
        # NOTE(review): the Hub may only populate sibling sizes when file
        # metadata is explicitly requested; if so this sum is always 0 and the
        # safetensors estimate below is what is actually returned. Confirm
        # before changing the request.
        total_size_bytes = sum(
            sibling.size
            for sibling in info.siblings or ()
            if sibling.rfilename.endswith(weight_extensions)
            and sibling.size is not None
        )

        # If no file sizes were available, estimate from safetensors metadata,
        # which gives parameter counts per dtype.
        if total_size_bytes == 0 and info.safetensors is not None:
            total_size_bytes = sum(
                # Unknown dtypes default to 2 bytes (FP16/BF16).
                int(param_count * DTYPE_BYTES_MAP.get(dtype, 2))
                for dtype, param_count in info.safetensors.parameters.items()
            )

        return total_size_bytes / (1024**2)
    except Exception as e:
        # Chain explicitly so the underlying Hub/network error stays visible.
        raise RuntimeError(f"Failed to get model info from Hub: {e}") from e


def get_model_weight_size(
    model_name_or_path: Union[str, Path],
) -> float:
    """Return the model weight size in MB for a local directory or a Hub repo.

    An argument that names an existing directory is measured on disk; anything
    else is passed, as a string, to the Hugging Face Hub lookup.
    """
    candidate = Path(model_name_or_path)
    is_local_dir = candidate.exists() and candidate.is_dir()

    if not is_local_dir:
        # Not an existing directory: treat the argument as a Hub model id.
        return get_model_weight_size_from_hub(str(model_name_or_path))

    return get_local_model_weight_size(model_name_or_path)


107
108
109
110
111
112
113
114
115
116
117
class ModelInfo(BaseModel):
    """Summary of a model's weight size and selected config fields.

    Instances are built by ``get_model_info`` in this module; the per-field
    comments describe how that function fills each value. Optional fields are
    ``None`` when the corresponding config attribute is not set.
    """

    # Weight size in MB, as returned by get_model_weight_size().
    model_size: float
    # First entry of the config's ``architectures`` list.
    architecture: str
    # True only when ``architecture`` is listed in MOE_ARCHITECTURES.
    is_moe: bool
    # First non-None config attribute among CONTEXT_LENGTH_ATTRS.
    max_context_length: Optional[int] = None
    # Expert count from the config; only looked up when ``is_moe`` is True.
    num_experts: Optional[int] = None
    # FFN hidden size, read from ``intermediate_size`` or ``ffn_dim``.
    intermediate_size: Optional[int] = None
    # Key/value head count; falls back to ``num_attention_heads`` when the
    # config has no dedicated KV-head attribute.
    num_kv_heads: Optional[int] = None
    # Block/group size from the config's ``quantization_config``; a list-valued
    # size (e.g. [128, 128]) is reduced to its largest element.
    quantization_block_size: Optional[int] = None


118
119
120
# Attribute names probed on the model config, each in priority order.
_EXPERT_COUNT_ATTRS = (
    "n_routed_experts",  # DeepSeek V3/R1
    "num_local_experts",  # Mixtral, Qwen
    "num_experts",  # Generic
)
_INTERMEDIATE_SIZE_ATTRS = (
    "intermediate_size",  # Most common (BERT, LLaMA, etc.)
    "ffn_dim",  # Some transformer models
)
_KV_HEAD_ATTRS = (
    "num_key_value_heads",  # LLaMA 2/3, Mistral, etc.
    "num_kv_heads",  # Alternative name
    "num_attention_heads",  # Fallback: standard MHA, KV heads == attention heads
)
# Keys/attributes of a quantization config that may carry the block size.
_QUANT_BLOCK_SIZE_KEYS = (
    "weight_block_size",
    "block_size",
    "group_size",
    "q_group_size",
)


def _first_set_attr(obj, names):
    """Return the first non-None attribute of *obj* among *names*, else None."""
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    return None


def _detect_quantization_block_size(config):
    """Return the quantization block size declared by *config*, or None."""
    quant_config = getattr(config, "quantization_config", None)
    if quant_config is None:
        return None

    if isinstance(quant_config, dict):
        # Falsy entries (None, 0, empty list) are skipped so that a later key
        # can still supply the size.
        block_size = next(
            (
                quant_config[key]
                for key in _QUANT_BLOCK_SIZE_KEYS
                if quant_config.get(key)
            ),
            None,
        )
    else:
        # Object-based quantization config: read the same names as attributes.
        block_size = _first_set_attr(quant_config, _QUANT_BLOCK_SIZE_KEYS)

    # Per-dimension block sizes (e.g. [128, 128] or (128, 128)) collapse to the
    # largest entry; an empty sequence carries no usable size.
    if isinstance(block_size, (list, tuple)):
        return max(block_size) if block_size else None
    return block_size


def get_model_info(
    model_name_or_path: Union[str, Path],
    trust_remote_code: bool = False,
) -> ModelInfo:
    """Collect the weight size and key config fields of a model.

    Args:
        model_name_or_path: Local model directory or Hugging Face Hub model id.
            It is used both for the weight-size lookup and for loading the
            config.
        trust_remote_code: Forwarded to ``AutoConfig.from_pretrained``.

    Returns:
        A ``ModelInfo`` holding the weight size in MB, the first declared
        architecture, the MoE flag, and the context length, expert count,
        intermediate size, KV-head count and quantization block size found in
        the config. Fields whose attribute is absent or None are ``None``.
    """
    model_size = get_model_weight_size(model_name_or_path)

    config = AutoConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=trust_remote_code,
    )

    # NOTE: this indexing raises if the config declares no architectures
    # (``architectures`` is None or empty).
    architecture = config.architectures[0]
    is_moe = architecture in MOE_ARCHITECTURES

    # Different models use different attribute names for the same quantity, so
    # each field takes the first candidate attribute that is set.
    max_context_length = _first_set_attr(config, CONTEXT_LENGTH_ATTRS)

    # The expert count is only reported for architectures treated as MoE.
    num_experts = _first_set_attr(config, _EXPERT_COUNT_ATTRS) if is_moe else None

    return ModelInfo(
        model_size=model_size,
        architecture=architecture,
        is_moe=is_moe,
        max_context_length=max_context_length,
        num_experts=num_experts,
        intermediate_size=_first_set_attr(config, _INTERMEDIATE_SIZE_ATTRS),
        num_kv_heads=_first_set_attr(config, _KV_HEAD_ATTRS),
        quantization_block_size=_detect_quantization_block_size(config),
    )
233
234
235
236
237
238
239
240
241
242


if __name__ == "__main__":
    # Command-line entry point: print the ModelInfo for the given --model.
    from argparse import ArgumentParser

    cli = ArgumentParser()
    cli.add_argument("--model", type=str, required=True)
    cli_args = cli.parse_args()

    print(get_model_info(cli_args.model))