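# NOTE(review): the following is residue from a web code-viewer scrape (file
# banner and line-number gutter), kept only as a comment so the module parses:
#   "deploy/operator/internal/discovery/resource.go" did not exist on "043c80c4b3413fc0ed7d3692d328a83ed5a5c89f"
#   profile_sla.py 31.9 KB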
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import math
import os

import numpy as np
import yaml

from benchmarks.profiler.utils.aiperf import benchmark_decode, benchmark_prefill
from benchmarks.profiler.utils.config import generate_dgd_config_with_planner
from benchmarks.profiler.utils.config_modifiers import CONFIG_MODIFIERS
from benchmarks.profiler.utils.estimate_perf import AIConfiguratorPerfEstimator
from benchmarks.profiler.utils.plot import (
    plot_decode_performance,
    plot_prefill_performance,
)
from benchmarks.profiler.utils.profile_cache import (
    check_decode_results_exist,
    check_prefill_results_exist,
    load_existing_decode_results,
    load_existing_prefill_results,
)
from benchmarks.profiler.utils.profile_decode import (
    get_num_request_range,
    profile_decode,
    profile_decode_aiconfigurator,
)
from benchmarks.profiler.utils.profile_prefill import (
    profile_prefill,
    profile_prefill_aiconfigurator,
)
from benchmarks.profiler.utils.profiler_argparse import create_profiler_parser
from deploy.utils.dynamo_deployment import (
    DynamoDeploymentClient,
    cleanup_remaining_deployments,
)
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES

# Module logger: INFO and above are echoed to the console with a timestamped
# format. The __main__ entry point adds a file handler on top of this.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
def _select_best_index(latencies, thpts_per_gpu, target):
    """Pick the configuration to recommend from a latency/throughput sweep.

    Among the entries whose latency is ``<= target``, return the index with the
    highest per-GPU throughput (first one on ties). If no entry meets the
    target, fall back to the index with the lowest latency.

    Args:
        latencies: Latency per configuration (TTFT or ITL, same unit as target).
        thpts_per_gpu: Per-GPU throughput, aligned index-by-index with
            ``latencies``.
        target: Latency SLA to satisfy.

    Returns:
        ``(index, met_target)``; ``met_target`` is False when the
        lowest-latency fallback was used.

    Raises:
        ValueError: If ``latencies`` is empty.
    """
    if not latencies:
        raise ValueError("Cannot select a configuration from empty results")
    valid_indices = [i for i, lat in enumerate(latencies) if lat <= target]
    if not valid_indices:
        return int(np.argmin(np.array(latencies))), False
    valid_thpts = [thpts_per_gpu[i] for i in valid_indices]
    return valid_indices[int(np.argmax(valid_thpts))], True


def _apply_parallelism(config_modifier, dgd_config, num_gpus, args, phase):
    """Return ``dgd_config`` with its parallelism set for ``num_gpus`` GPUs.

    MoE models use ``set_config_tep_size`` for the "prefill" phase and
    ``set_config_dep_size`` for "decode" (both given
    ``args.num_gpus_per_node``); dense models use ``set_config_tp_size``.
    """
    if args.is_moe_model:
        if phase == "prefill":
            return config_modifier.set_config_tep_size(
                dgd_config, num_gpus, args.num_gpus_per_node
            )
        return config_modifier.set_config_dep_size(
            dgd_config, num_gpus, args.num_gpus_per_node
        )
    return config_modifier.set_config_tp_size(dgd_config, num_gpus)


async def run_profile(args):
    """Profile prefill/decode engines against the TTFT/ITL SLA and emit a planner DGD.

    Steps:
      1. Sweep GPU counts for prefill and record TTFT, either from a real
         deployment benchmarked with aiperf or from ai-configurator estimates.
      2. Sweep GPU counts and request counts for decode, recording ITL and
         per-GPU throughput.
      3. (non dry-run) Recommend the prefill/decode GPU counts that meet
         ``args.ttft`` / ``args.itl`` with the highest per-GPU throughput, and
         log suggested planner bounds.
      4. Re-profile the selected prefill/decode GPU counts for interpolation.
      5. Write ``config_with_planner.yaml`` to ``args.output_dir``.

    Every deployment created along the way is tracked in
    ``deployment_clients`` and cleaned up in the ``finally`` clause.

    Args:
        args: Parsed namespace from ``create_profiler_parser()``.

    Returns:
        None. Returns early, without writing the planner config, when no
        prefill or decode results were produced.

    Raises:
        ValueError: On unsupported flag combinations (MoE with a non-SGLang
            backend or with ai-configurator) or when ai-configurator is
            requested without its required ``--aic-*`` flags.
    """
    # List to track all created deployment clients for cleanup in case of failure
    deployment_clients = []

    # Inherit aic_backend from backend if not explicitly set
    if not args.aic_backend:
        args.aic_backend = args.backend

    try:
        # Log MoE model support
        if args.is_moe_model:
            logger.info(
                "MoE (Mixture of Experts) model profiling, sweeping TEP size for prefill and DEP size for decode"
            )
            # Explicit raises (not assert) so the checks survive `python -O`.
            if args.backend not in ["sglang"]:
                raise ValueError("MoE model support is only available for SGLang")
            if args.use_ai_configurator:
                raise ValueError("MoE model is not supported in ai-configurator")
        else:
            logger.info(
                "Standard dense model profiling, sweeping TP size for both prefill and decode"
            )

        config_modifier = CONFIG_MODIFIERS[args.backend]

        with open(args.config, "r") as f:
            config = yaml.safe_load(f)

        config = config_modifier.update_model(config, args.model)
        if args.dgd_image:
            config = config_modifier.update_image(config, args.dgd_image)
            logger.info(f"Using DGD image: {args.dgd_image}")

        if args.is_moe_model:
            # For MoE models, use range with stride of num_gpus_per_node
            profile_num_gpus = list(
                range(
                    args.min_num_gpus_per_engine,
                    args.max_num_gpus_per_engine + 1,
                    args.num_gpus_per_node,
                )
            )
            logger.info(f"Profiling MoE GPU counts (TEP/DEP): {profile_num_gpus}")
        else:
            # For dense models, use powers of 2
            profile_num_gpus = [
                2**i
                for i in range(int(math.log2(args.max_num_gpus_per_engine)) + 1)
                if args.min_num_gpus_per_engine
                <= 2**i
                <= args.max_num_gpus_per_engine
            ]
            logger.info(f"Profiling dense model GPU counts (TP): {profile_num_gpus}")

        os.makedirs(args.output_dir, exist_ok=True)

        model_name = config_modifier.get_model_name(config)

        # Log skip behavior
        if args.force_rerun:
            logger.info(
                "Force rerun enabled - will re-run all tests even if results exist"
            )
        elif args.skip_existing_results:
            logger.info(
                "Skip existing results enabled - will skip TP sizes with existing results"
            )
        else:
            logger.info("Skip existing results disabled - will re-run all tests")

        if args.use_ai_configurator:
            if not args.aic_system:
                raise ValueError(
                    "Must provide --aic-system when using --use-ai-configurator."
                )
            if not args.aic_model_name:
                raise ValueError(
                    "Must provide --aic-model-name when using --use-ai-configurator."
                )
            if not args.aic_backend_version:
                raise ValueError(
                    "Must provide --aic-backend-version when using --use-ai-configurator."
                )

            logger.info("Will use aiconfigurator to estimate perf.")
            ai_configurator_perf_estimator = AIConfiguratorPerfEstimator(
                args.aic_model_name,
                args.aic_system.lower(),
                args.aic_backend,
                args.aic_backend_version,
            )
        else:
            if args.aic_system or args.aic_model_name or args.aic_backend_version:
                # Name the real flag (--aic-backend-version) in the warning.
                logger.warning(
                    "Will ignore --aic-system, --aic-model-name, and/or --aic-backend-version "
                    "when not using --use-ai-configurator."
                )

        # first profile prefill
        prefill_num_gpus = []
        prefill_ttft = []
        prefill_thpt_per_gpu = []
        logger.info("Profiling prefill...")
        prefill_config = config_modifier.convert_config(
            config, "prefill", is_moe_model=args.is_moe_model
        )
        frontend_port = config_modifier.get_port(config)
        itl: float | None = None
        thpt_per_gpu: float | None = None
        for num_gpus in profile_num_gpus:
            logger.info(f"Profiling prefill with {num_gpus} GPUs...")

            # Check if results already exist for this GPU count
            if (
                args.skip_existing_results
                and not args.force_rerun
                and check_prefill_results_exist(args.output_dir, num_gpus, args.isl)
            ):
                logger.info(
                    f"Skipping prefill {num_gpus} GPU(s) - results already exist"
                )
                ttft, thpt_per_gpu = load_existing_prefill_results(
                    args.output_dir, num_gpus, args.isl
                )
                if ttft is not None and thpt_per_gpu is not None:
                    prefill_num_gpus.append(num_gpus)
                    prefill_ttft.append(ttft)
                    prefill_thpt_per_gpu.append(thpt_per_gpu)
                    logger.info(
                        f"Loaded existing prefill results: {num_gpus} GPU TTFT={ttft:.2f}ms, throughput={thpt_per_gpu:.2f} tokens/s/GPU"
                    )
                continue

            prefill_config = _apply_parallelism(
                config_modifier, prefill_config, num_gpus, args, "prefill"
            )
            logger.info(f"Dynamo config: {prefill_config}")

            work_dir = f"{args.output_dir}/prefill_{num_gpus}gpus"
            os.makedirs(work_dir, exist_ok=True)

            prefill_config_fn = f"{work_dir}/config.yaml"
            with open(prefill_config_fn, "w") as f:
                yaml.dump(prefill_config, f)

            ttft = None
            if args.dry_run:
                logger.info("Skipping deployment creation in dry run mode")
            elif args.use_ai_configurator:
                logger.info("Using ai-configurator to estimate prefill latency.")
                perf_dict = ai_configurator_perf_estimator.estimate_prefill_perf(
                    args.isl,
                    tp_size=num_gpus,
                )
                ttft = perf_dict["context_latency"]
                logger.info(f"Estimated prefill TTFT: {ttft:.2f}ms")
            else:
                client = DynamoDeploymentClient(
                    namespace=args.namespace,
                    base_log_dir=work_dir,
                    model_name=model_name,
                    service_name=args.service_name,
                    frontend_port=frontend_port,
                    deployment_name=prefill_config["metadata"]["name"],
                )
                logger.info(f"Created client with service_name: {client.service_name}")
                deployment_clients.append(client)  # Track for cleanup
                await client.create_deployment(prefill_config_fn)
                logger.info("Waiting for deployment to be ready...")
                await client.wait_for_deployment_ready()
                logger.info("Deployment is ready")

                logger.info("Getting deployment logs...")
                await client.get_deployment_logs()
                logger.info(
                    f"Logs have been saved to {client.base_log_dir / client.deployment_name}"
                )

                # run ai-perf
                base_url = client.get_service_url()
                ai_perf_artifact_dir = f"{work_dir}/aiperf_isl{args.isl}"
                aiperf_result = benchmark_prefill(
                    args.isl,
                    ai_perf_artifact_dir,
                    model_name,
                    model_name,
                    base_url=base_url,
                )
                if aiperf_result is not None:
                    ttft = aiperf_result["time_to_first_token"]["avg"]

                logger.info("Cleaning up deployment...")
                await client.delete_deployment()
                deployment_clients.remove(client)
                logger.info("Deployment deleted")

            if ttft is not None:
                prefill_num_gpus.append(num_gpus)
                prefill_ttft.append(ttft)
                # tokens/s/GPU: ISL tokens per TTFT (ms -> s), split over GPUs
                prefill_thpt_per_gpu.append(args.isl / ttft / num_gpus * 1000)

        # Plot the results as a 2D scatter plot
        if prefill_num_gpus and prefill_ttft and prefill_thpt_per_gpu:
            plot_prefill_performance(
                prefill_num_gpus,
                prefill_ttft,
                prefill_thpt_per_gpu,
                args.ttft,
                args.output_dir,
            )

        # then profile decode
        decode_num_gpus = []
        decode_itl = []
        decode_thpt_per_gpu = []
        decode_concurrency = []
        decode_kv_cache_size = []
        decode_results = []  # Store partial results for plotting later
        logger.info("Profiling decode...")
        decode_config = config_modifier.convert_config(
            config, "decode", is_moe_model=args.is_moe_model
        )
        for num_gpus in profile_num_gpus:
            logger.info(f"Profiling decode with {num_gpus} GPUs...")

            # Check if results already exist for this GPU count
            if (
                args.skip_existing_results
                and not args.force_rerun
                and check_decode_results_exist(
                    args.output_dir, num_gpus, args.isl, args.osl
                )
            ):
                logger.info(
                    f"Skipping decode {num_gpus} GPU(s) - results already exist"
                )
                existing_results = load_existing_decode_results(
                    args.output_dir, num_gpus, args.isl, args.osl
                )
                if existing_results:
                    # Add existing results to our arrays
                    engine_decode_itl = []
                    engine_decode_thpt_per_gpu = []
                    for itl, thpt_per_gpu, concurrency in existing_results:
                        decode_num_gpus.append(num_gpus)
                        decode_itl.append(itl)
                        decode_thpt_per_gpu.append(thpt_per_gpu)
                        decode_concurrency.append(concurrency)
                        # We need to get kv_cache_size from existing logs or estimate it
                        estimated_kv_cache = max(
                            100000, concurrency * (args.isl + args.osl) * 2
                        )  # Conservative estimate
                        decode_kv_cache_size.append(estimated_kv_cache)
                        engine_decode_itl.append(itl)
                        engine_decode_thpt_per_gpu.append(thpt_per_gpu)

                    # Store results for plotting
                    decode_results.append(
                        (num_gpus, engine_decode_itl, engine_decode_thpt_per_gpu)
                    )
                    logger.info(
                        f"Loaded {len(existing_results)} existing decode results for {num_gpus} GPU(s)"
                    )
                continue

            decode_config = _apply_parallelism(
                config_modifier, decode_config, num_gpus, args, "decode"
            )
            logger.info(f"Dynamo config: {decode_config}")

            work_dir = f"{args.output_dir}/decode_{num_gpus}gpus"
            os.makedirs(work_dir, exist_ok=True)

            decode_config_fn = f"{work_dir}/config.yaml"
            with open(decode_config_fn, "w") as f:
                yaml.dump(decode_config, f)

            if args.dry_run:
                logger.info("Skipping deployment creation in dry run mode")
            elif args.use_ai_configurator:
                # Compute max_concurrency and max_kv_tokens to know which
                # num_request to sweep over.
                max_concurrency = ai_configurator_perf_estimator.get_max_batch_size(
                    args.isl, args.osl, tp_size=num_gpus
                )
                max_kv_tokens = max_concurrency * (args.isl + args.osl)
            else:
                client = DynamoDeploymentClient(
                    namespace=args.namespace,
                    base_log_dir=work_dir,
                    model_name=model_name,
                    service_name=args.service_name,
                    frontend_port=frontend_port,
                    deployment_name=decode_config["metadata"]["name"],
                )
                deployment_clients.append(client)  # Track for cleanup
                await client.create_deployment(decode_config_fn)
                logger.info("Waiting for deployment to be ready...")
                await client.wait_for_deployment_ready()
                logger.info("Deployment is ready")

                logger.info("Getting deployment logs...")
                await client.get_deployment_logs()
                logger.info(
                    f"Logs have been saved to {client.base_log_dir / client.deployment_name}"
                )

                # Compute max_concurrency and max_kv_tokens to know which
                # num_request to sweep over.
                # For MoE models, attention_dp_size = DEP size (num_gpus), for dense models = 1
                attention_dp_size = num_gpus if args.is_moe_model else 1
                max_kv_tokens = config_modifier.get_kv_cache_size_from_dynamo_log(
                    f"{work_dir}/{client.deployment_name}/{WORKER_COMPONENT_NAMES[args.backend].decode_worker_k8s_name.lower()}/0.log",
                    attention_dp_size=attention_dp_size,
                )
                max_concurrency = max_kv_tokens // (args.isl + args.osl)

            if not args.dry_run:
                attention_dp_size = num_gpus if args.is_moe_model else 1
                sweep_num_request = get_num_request_range(
                    attention_dp_size,
                    max_concurrency,
                    args.decode_interpolation_granularity,
                )
                logger.info(
                    f"Sweeping num_request range based on maximum number of kv tokens: {sweep_num_request}"
                )

                engine_decode_itl = []
                engine_decode_thpt_per_gpu = []
                for num_request in sweep_num_request:
                    itl = thpt_per_gpu = None
                    if args.use_ai_configurator:
                        logger.info("Using ai-configurator to estimate decode latency.")
                        perf_dict = ai_configurator_perf_estimator.estimate_perf(
                            args.isl,
                            args.osl,
                            num_request,
                            mode="decode",
                            tp_size=num_gpus,
                        )

                        itl = perf_dict["tpot"]
                        thpt_per_gpu = perf_dict["tokens/s/gpu"]
                        logger.info(f"Estimated decode ITL: {itl:.2f}ms")
                        logger.info(
                            f"Estimated decode throughput per GPU: {thpt_per_gpu:.2f} tokens/s/GPU"
                        )
                    else:
                        base_url = client.get_service_url()
                        ai_perf_artifact_dir = f"{work_dir}/aiperf_request{num_request}_isl{args.isl}_osl{args.osl}_n{num_request}"
                        aiperf_result = benchmark_decode(
                            args.isl,
                            args.osl,
                            num_request,
                            ai_perf_artifact_dir,
                            model_name,
                            model_name,
                            base_url=base_url,
                        )
                        if aiperf_result is not None:
                            itl = aiperf_result["inter_token_latency"]["avg"]
                            thpt_per_gpu = (
                                aiperf_result["output_token_throughput"]["avg"]
                                / num_gpus
                            )

                    if itl is not None and thpt_per_gpu is not None:
                        engine_decode_itl.append(itl)
                        engine_decode_thpt_per_gpu.append(thpt_per_gpu)
                        decode_num_gpus.append(num_gpus)
                        decode_itl.append(itl)
                        decode_thpt_per_gpu.append(thpt_per_gpu)
                        decode_concurrency.append(num_request)
                        decode_kv_cache_size.append(max_kv_tokens)

                # Store partial results for plotting later
                decode_results.append(
                    (num_gpus, engine_decode_itl, engine_decode_thpt_per_gpu)
                )

            if not args.dry_run and not args.use_ai_configurator:
                logger.info("Cleaning up deployment...")
                await client.delete_deployment()
                deployment_clients.remove(client)
                logger.info("Deployment deleted")

        # Plot all decode results after profiling is complete
        if decode_results:
            plot_decode_performance(decode_results, args.itl, args.output_dir)

        if args.dry_run:
            logger.info("Skipping recommendations in dry run mode")
        else:
            logger.info("Analyzing results and generate recommendations...")
            # Safety guards: no results -> exit early with a clear message.
            # Returning here (instead of falling through) avoids min() on an
            # empty list; cleanup still runs in the finally clause.
            if not (prefill_num_gpus and prefill_ttft and prefill_thpt_per_gpu):
                logger.error("No prefill results produced; skipping recommendations.")
                return

            # select best tp size for prefill
            selected_prefill_idx, prefill_sla_met = _select_best_index(
                prefill_ttft, prefill_thpt_per_gpu, args.ttft
            )
            if not prefill_sla_met:
                logger.info(
                    "No TP size satisfies the TTFT requirement, please try a smaller model or a more powerful GPU SKU"
                )
            logger.info(
                f"Suggested number of GPUs for prefill: {prefill_num_gpus[selected_prefill_idx]} (TTFT {prefill_ttft[selected_prefill_idx]:.2f} ms, throughput {prefill_thpt_per_gpu[selected_prefill_idx]:.2f} tokens/s/GPU)"
            )

            # scale up if estimated TTFT is 120% of target TTFT
            prefill_queue_size_upper_bound = max(
                0.1, args.ttft * 1.2 / prefill_ttft[selected_prefill_idx] - 1
            )
            # scale down if estimated TTFT is 80% of target TTFT
            prefill_queue_size_lower_bound = max(
                0.1, args.ttft * 0.8 / prefill_ttft[selected_prefill_idx] - 1
            )
            logger.info(
                f"Suggested planner upper/lower bound for prefill queue size: {prefill_queue_size_upper_bound:.2f}/{prefill_queue_size_lower_bound:.2f}"
            )

            # select best gpu count for decode
            if not (
                decode_num_gpus
                and decode_itl
                and decode_thpt_per_gpu
                and decode_concurrency
                and decode_kv_cache_size
            ):
                logger.error("No decode results produced; skipping recommendations.")
                return
            selected_decode_idx, decode_sla_met = _select_best_index(
                decode_itl, decode_thpt_per_gpu, args.itl
            )
            if not decode_sla_met:
                logger.info(
                    "No TP size satisfies the ITL requirement, please try a smaller model or a more powerful GPU SKU"
                )
            logger.info(
                f"Suggested number of GPUs for decode: {decode_num_gpus[selected_decode_idx]} (ITL {decode_itl[selected_decode_idx]:.2f} ms, throughput {decode_thpt_per_gpu[selected_decode_idx]:.2f} tokens/s/GPU)"
            )

            # calculate kv cache utilization for the selected TP and concurrency
            selected_decode_kv_cache_utilization = (
                decode_concurrency[selected_decode_idx]
                * (args.isl + (args.osl / 2))
                / decode_kv_cache_size[selected_decode_idx]
            )
            # set a +- 20% range for the kv cache utilization
            logger.info(
                f"Suggested planner upper/lower bound for decode kv cache utilization: {min(1, selected_decode_kv_cache_utilization + 0.2):.2f}/{max(0.1, selected_decode_kv_cache_utilization - 0.2):.2f}"
            )

        if args.dry_run:
            # use min value for prefill and decode GPU counts
            prefill_num_gpus = [args.min_num_gpus_per_engine]
            decode_num_gpus = [args.min_num_gpus_per_engine]
            selected_prefill_idx = 0
            selected_decode_idx = 0

        # interpolate ISL - TTFT with best prefill GPU count
        best_prefill_gpus = prefill_num_gpus[selected_prefill_idx]
        logger.info(
            f"Profiling prefill under best {best_prefill_gpus} GPU(s) with different ISL..."
        )
        prefill_config = config_modifier.convert_config(
            config, "prefill", is_moe_model=args.is_moe_model
        )
        prefill_config = _apply_parallelism(
            config_modifier, prefill_config, best_prefill_gpus, args, "prefill"
        )
        logger.info(f"Dynamo config: {prefill_config}")

        work_dir = f"{args.output_dir}/selected_prefill_interpolation"
        os.makedirs(work_dir, exist_ok=True)

        prefill_config_fn = f"{work_dir}/config.yaml"
        with open(prefill_config_fn, "w") as f:
            yaml.dump(prefill_config, f)

        if args.dry_run:
            logger.info("Skipping deployment creation in dry run mode")
        elif args.use_ai_configurator:
            profile_prefill_aiconfigurator(
                work_dir,
                best_prefill_gpus,  # num_gpus
                args.max_context_length,
                args.prefill_interpolation_granularity,
                ai_configurator_perf_estimator,
                tp_size=best_prefill_gpus,
            )
        else:
            client = DynamoDeploymentClient(
                namespace=args.namespace,
                base_log_dir=work_dir,
                model_name=model_name,
                service_name=args.service_name,
                frontend_port=frontend_port,
                deployment_name=prefill_config["metadata"]["name"],
            )
            deployment_clients.append(client)  # Track for cleanup
            await client.create_deployment(prefill_config_fn)
            logger.info("Waiting for deployment to be ready...")
            try:
                await client.wait_for_deployment_ready()
            except TimeoutError:
                logger.error(
                    "Deployment or model failed to become ready within timeout, skipping profiling"
                )
                skip_profile = True
            else:
                logger.info("Deployment is ready")
                skip_profile = False

            # Only fetch logs and benchmark a deployment that actually became
            # ready; otherwise go straight to cleanup.
            if not skip_profile:
                logger.info("Getting deployment logs...")
                await client.get_deployment_logs()
                logger.info(
                    f"Logs have been saved to {client.base_log_dir / client.deployment_name}"
                )

                base_url = client.get_service_url()
                profile_prefill(
                    work_dir,
                    model_name,
                    model_name,
                    base_url,
                    best_prefill_gpus,
                    args.max_context_length,
                    args.prefill_interpolation_granularity,
                )

            logger.info("Cleaning up deployment...")
            await client.delete_deployment()
            deployment_clients.remove(client)
            logger.info("Deployment deleted")

        # interpolate ITL - Active_KV_Cache - Decode_Context_Length with best decode GPU count
        best_decode_gpus = decode_num_gpus[selected_decode_idx]
        logger.info(f"Profiling decode with {best_decode_gpus} GPUs...")
        decode_config = _apply_parallelism(
            config_modifier, decode_config, best_decode_gpus, args, "decode"
        )
        logger.info(f"Dynamo config: {decode_config}")

        work_dir = f"{args.output_dir}/selected_decode_interpolation"
        os.makedirs(work_dir, exist_ok=True)

        decode_config_fn = f"{work_dir}/config.yaml"
        with open(decode_config_fn, "w") as f:
            yaml.dump(decode_config, f)

        if args.dry_run:
            logger.info("Skipping deployment creation in dry run mode")
        elif args.use_ai_configurator:
            # For MoE models, attention_dp_size = DEP size (best_decode_gpus), for dense models = 1
            attention_dp_size = best_decode_gpus if args.is_moe_model else 1
            max_kv_tokens = ai_configurator_perf_estimator.get_max_kv_tokens(
                args.isl, args.osl, tp_size=best_decode_gpus
            )
            profile_decode_aiconfigurator(
                work_dir,
                best_decode_gpus,  # num_gpus
                max_kv_tokens,
                args.max_context_length,
                args.decode_interpolation_granularity,
                ai_configurator_perf_estimator,
                attention_dp_size,
                tp_size=best_decode_gpus,
            )
        else:
            client = DynamoDeploymentClient(
                namespace=args.namespace,
                base_log_dir=work_dir,
                model_name=model_name,
                service_name=args.service_name,
                frontend_port=frontend_port,
                deployment_name=decode_config["metadata"]["name"],
            )
            deployment_clients.append(client)  # Track for cleanup
            await client.create_deployment(decode_config_fn)
            logger.info("Waiting for deployment to be ready...")
            await client.wait_for_deployment_ready()
            logger.info("Deployment is ready")

            logger.info("Getting deployment logs...")
            await client.get_deployment_logs()
            logger.info(
                f"Logs have been saved to {client.base_log_dir / client.deployment_name}"
            )

            # For MoE models, attention_dp_size = DEP size (best_decode_gpus), for dense models = 1
            attention_dp_size = best_decode_gpus if args.is_moe_model else 1
            max_kv_tokens = config_modifier.get_kv_cache_size_from_dynamo_log(
                f"{work_dir}/{client.deployment_name}/{WORKER_COMPONENT_NAMES[args.backend].decode_worker_k8s_name.lower()}/0.log",
                attention_dp_size=attention_dp_size,
            )

            base_url = client.get_service_url()

            profile_decode(
                work_dir,
                model_name,
                model_name,
                base_url,
                best_decode_gpus,
                max_kv_tokens,
                args.max_context_length,
                args.decode_interpolation_granularity,
                attention_dp_size,
            )

            logger.info("Cleaning up deployment...")
            await client.delete_deployment()
            deployment_clients.remove(client)
            logger.info("Deployment deleted")

        # generate DGD with planner based on profiling results
        config = generate_dgd_config_with_planner(
            config_path=args.config,
            config_modifier=config_modifier,
            best_prefill_gpus=best_prefill_gpus,
            best_decode_gpus=best_decode_gpus,
            output_dir=args.output_dir,
            args=args,
            is_moe_model=args.is_moe_model,
            num_gpus_per_node=args.num_gpus_per_node,
        )
        logger.info(f"Final DGD config with planner: {config}")

        # save DGD config with planner
        with open(f"{args.output_dir}/config_with_planner.yaml", "w") as f:
            yaml.dump(config, f)

    except Exception as e:
        logger.error(f"Profile job failed with error: {e}")
        raise
    finally:
        # Always clean up any remaining deployments, even if the job failed
        logger.info("Performing final cleanup of any remaining deployments...")
        await cleanup_remaining_deployments(deployment_clients, args.namespace)
        logger.info("Final cleanup completed.")


if __name__ == "__main__":
    args = create_profiler_parser()

    # Mirror the module logger into <output_dir>/profile_sla.log, using the
    # same timestamped format as the console handler.
    os.makedirs(args.output_dir, exist_ok=True)
    log_file_handler = logging.FileHandler(
        f"{args.output_dir}/profile_sla.log", encoding="utf-8"
    )
    log_file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
    )
    log_file_handler.setFormatter(formatter)
    logger.addHandler(log_file_handler)

    asyncio.run(run_profile(args))