# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import logging
import math
import os

import yaml

from benchmarks.profiler.utils.config_modifiers import CONFIG_MODIFIERS
12
from benchmarks.profiler.utils.model_info import ModelInfo, get_model_info
13
14
15
16
17
18
19
20
21
22
23
24
25
from deploy.utils.gpu_inventory import get_gpu_summary

# Module-level logging: records at INFO and above are written to a dedicated
# stream handler using a timestamped "time - logger - level - message" layout.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)

# Upper bound on the fraction of a GPU's memory budgeted for model weights when
# computing the minimum number of GPUs per engine (the model size is divided by
# this fraction before being divided by the per-GPU VRAM).
MODEL_GPU_MEM_FRAC_MAX = 0.9
26
27
28

# For MoE models, the upper end of the GPU sweep is this factor times the
# minimum number of GPUs needed to hold the model weights (i.e. enough GPUs to
# hold 8x the weights), and never less than the number of GPUs on one node.
MOE_MODEL_MAX_NUM_GPU_FACTOR = 8
29
30
31
32
33
34
35


def auto_generate_search_space(args: argparse.Namespace) -> None:
    """Resolve the profiling config, model info and GPU search space on ``args``.

    ``args`` is mutated in place:

    * ``args.config`` -- when ``args.model`` is given, the (default or
      user-provided) DGD config is updated with the model and the optional
      image, written to ``<args.output_dir>/disagg_config.yaml``, and
      ``args.config`` is pointed at that file.
    * ``args.model`` / ``args.model_path`` -- read from the config when
      ``args.model`` is not given.
    * ``args.model_info`` -- the result of ``get_model_info``.
    * ``args.min_num_gpus_per_engine`` / ``args.max_num_gpus_per_engine`` /
      ``args.num_gpus_per_node`` -- a value of 0 means "unspecified"; such
      values are filled in from cluster discovery when
      ``args.enable_gpu_discovery`` is set, otherwise from fixed defaults
      (1 / 4 / 8).

    Args:
        args: Parsed command-line arguments of the profiler.

    Raises:
        RuntimeError: If GPU discovery is required but no model is known, the
            cluster reports a non-positive VRAM size or GPU count, or the
            model cannot fit into the largest allowed GPU configuration.
    """
    config_modifier = CONFIG_MODIFIERS[
        args.backend
    ]  # args.backend is already validated in argparse

    config = _load_profiling_config(args, config_modifier)
    args.model_info = _load_model_info(args, config_modifier, config)

    if args.enable_gpu_discovery:
        _discover_gpu_search_space(args, args.model_info)
    else:
        _apply_default_gpu_search_space(args)


def _load_profiling_config(args: argparse.Namespace, config_modifier):
    """Load the DGD config and, if ``args.model`` is set, save an updated copy.

    Returns the loaded (and possibly updated) config object. ``args.config`` is
    only rewritten when ``args.model`` is given.
    """
    if not args.config:
        # modify config file from default config file
        logger.info("DGD config file not provided, using default config file")
        config = config_modifier.load_default_config()
    else:
        # YAML streams are UTF-8 by default; do not depend on the locale encoding.
        with open(args.config, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)

    if not args.model:
        return config

    logger.info(f"Updating model in DGD config file to {args.model}")
    if args.model_cache_pvc_name:
        config = config_modifier.update_model_from_pvc(
            config,
            args.model,
            args.model_cache_pvc_name,
            args.model_cache_pvc_mount_path,
            args.model_cache_pvc_path,
        )
    else:
        # Non-PVC: workers download from HF, so model_path == model_name
        config = config_modifier.update_model(config, args.model, args.model)

    # NOTE(review): the image override is only applied when --model is given;
    # confirm that this coupling is intended.
    if args.dgd_image:
        logger.info(f"Updating DGD image to {args.dgd_image}")
        config = config_modifier.update_image(config, args.dgd_image)

    config_fn = f"{args.output_dir}/disagg_config.yaml"
    logger.info(f"Saving generated disagg DGD config for profiling to {config_fn}")
    os.makedirs(args.output_dir, exist_ok=True)
    with open(config_fn, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
    args.config = config_fn
    return config


def _load_model_info(args: argparse.Namespace, config_modifier, config) -> "ModelInfo":
    """Look up model info, falling back from the resolved path to ``args.model``.

    Sets ``args.model`` and ``args.model_path`` from the config when no model
    was given on the command line.
    """
    if args.model:
        # prioritize using model cache in PVC over downloading from HF
        if args.model_cache_pvc_name:
            # Keep consistent path normalization with config mutation logic
            model_name_or_path = config_modifier._normalize_model_path(
                args.model_cache_pvc_mount_path, args.model_cache_pvc_path
            )
        else:
            model_name_or_path = args.model
    else:
        # get the model name from config
        args.model, args.model_path = config_modifier.get_model_name(config)
        model_name_or_path = args.model_path

    logger.info(f"Getting model info for {args.model} at {model_name_or_path}...")
    try:
        model_info = get_model_info(model_name_or_path)
    except Exception as e:
        if model_name_or_path == args.model:
            # The fallback below would repeat the exact same lookup; surface the
            # original failure instead of retrying and logging a misleading hint.
            raise
        # Common in dry-run mode when the PVC isn't mounted locally.
        logger.warning(
            f"Failed to load model info from local path '{model_name_or_path}': {e}. "
            f"Trying to download from HF for '{args.model}'."
        )
        model_info = get_model_info(args.model)

    num_experts_str = (
        f", num_experts={model_info.num_experts}"
        if model_info.num_experts is not None
        else ""
    )
    logger.info(
        f"Model {args.model} has size {model_info.model_size}, is_moe={model_info.is_moe}, and max_context_length={model_info.max_context_length}{num_experts_str}"
    )
    return model_info


def _discover_gpu_search_space(args: argparse.Namespace, model_info: "ModelInfo") -> None:
    """Derive the GPU search space from the cluster when any GPU arg is 0."""
    if 0 not in (
        args.min_num_gpus_per_engine,
        args.max_num_gpus_per_engine,
        args.num_gpus_per_node,
    ):
        # Everything was specified explicitly; nothing to discover.
        return

    if not args.model:
        # TODO: get model info provided DGD config
        error_msg = "No model provided, cannot auto-generate GPU search space. Please provide `--model` or GPU info"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    logger.info("Getting GPU info from k8s cluster...")
    gpu_info = get_gpu_summary()
    logger.info(
        f"Cluster has {gpu_info['gpus_per_node']}x{gpu_info['model']} GPUs per node with {gpu_info['vram']} VRAM"
    )

    vram_mib = int(gpu_info["vram"])  # type: ignore[call-overload]
    gpus_per_node = int(gpu_info["gpus_per_node"])  # type: ignore[call-overload]
    if vram_mib <= 0 or gpus_per_node <= 0:
        # Fail with a clear message rather than a ZeroDivisionError below.
        error_msg = (
            f"Invalid GPU info reported by the cluster: "
            f"gpus_per_node={gpu_info['gpus_per_node']}, vram={gpu_info['vram']}"
        )
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    # Smallest GPU count whose combined usable memory holds the model weights.
    # NOTE(review): assumes model_size and vram use the same unit (MiB) -- confirm.
    # Clamped to 1 because 0 is the "unspecified" sentinel for these args.
    min_gpu = max(
        1,
        math.ceil(model_info.model_size / MODEL_GPU_MEM_FRAC_MAX / vram_mib),
    )
    if model_info.is_moe:
        max_gpu = max(min_gpu * MOE_MODEL_MAX_NUM_GPU_FACTOR, gpus_per_node)
    else:
        max_gpu = gpus_per_node
    if min_gpu > max_gpu:
        error_msg = f"No valid GPU configuration found for model {args.model} on the cluster with {gpu_info['gpus_per_node']}x{gpu_info['model']} GPUs per node"
        logger.error(error_msg)
        raise RuntimeError(error_msg)

    logger.info(
        f"Auto-generated search space for model {args.model} on the cluster with {gpu_info['gpus_per_node']}x{gpu_info['model']} GPUs per node: {min_gpu} to {max_gpu}"
    )
    # NOTE(review): all three values are overwritten even if only one of them
    # was left unspecified -- confirm this is the intended precedence.
    args.min_num_gpus_per_engine = min_gpu
    args.max_num_gpus_per_engine = max_gpu
    args.num_gpus_per_node = gpus_per_node  # type: ignore[assignment]


def _apply_default_gpu_search_space(args: argparse.Namespace) -> None:
    """Replace unspecified (0) GPU args with fixed defaults when discovery is off."""
    # use default values for GPUs
    if args.min_num_gpus_per_engine == 0:
        logger.warning(
            "GPU discover is disabled and min_num_gpus_per_engine is not specified, setting to 1"
        )
        args.min_num_gpus_per_engine = 1
    if args.max_num_gpus_per_engine == 0:
        logger.warning(
            "GPU discover is disabled and max_num_gpus_per_engine is not specified, setting to 4"
        )
        args.max_num_gpus_per_engine = 4
    if args.num_gpus_per_node == 0:
        logger.warning(
            "GPU discover is disabled and num_gpus_per_node is not specified, setting to 8"
        )
        args.num_gpus_per_node = 8