render.py 3.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import re
from pathlib import Path

import yaml
from jinja2 import Environment, FileSystemLoader


def parse_args():
    parser = argparse.ArgumentParser(
        description="Renders dynamo Dockerfiles from templates"
    )
    parser.add_argument(
        "--framework",
        type=str,
        default="vllm",
        help="Dockerfile framework to use [dynamo, vllm, sglang, trtllm]",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="runtime",
        help="Dockerfile target to use. Non-exhaustive examples: [runtime, dev, local-dev]",
    )
    parser.add_argument(
        "--platform",
        type=str,
        default="amd64",
        help="Dockerfile platform to use. [amdg64, arm64]",
    )
    parser.add_argument(
        "--cuda-version",
        type=str,
        default="12.9",
        help="CUDA version to use. [12.9, 13.0]",
    )
    parser.add_argument("--make-efa", action="store_true", help="Enable AWS EFA")
    parser.add_argument(
43
        "--output-short-filename",
44
        action="store_true",
45
        help="Output filename is rendered.Dockerfile instead of <framework>-<target>-cuda<cuda_version>-<arch>-rendered.Dockerfile",
46
47
48
49
50
51
52
53
54
55
56
    )
    parser.add_argument(
        "--show-result",
        action="store_true",
        help="Prints the rendered Dockerfile to stdout.",
    )
    args = parser.parse_args()
    return args


def validate_args(args):
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    valid_inputs = {
        "vllm": {"runtime", "dev", "local-dev", "framework", "wheel_builder", "base"},
        "trtllm": {"runtime", "dev", "local-dev", "framework", "wheel_builder", "base"},
        "sglang": {"runtime", "dev", "local-dev", "wheel_builder", "base"},
        "dynamo": {"runtime", "dev", "local-dev", "frontend", "wheel_builder", "base"},
    }

    if args.framework in valid_inputs:
        if args.target in valid_inputs[args.framework]:
            return
        raise ValueError(
            f"Invalid input combination: [framework={args.framework},target={args.target}]"
        )

    raise ValueError(
        f"Invalid input combination: [framework={args.framework},target={args.target}]"
    )
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
    return


def render(args, context, script_dir):
    env = Environment(
        loader=FileSystemLoader(script_dir), trim_blocks=False, lstrip_blocks=True
    )
    template = env.get_template("Dockerfile.template")
    rendered = template.render(
        context=context,
        framework=args.framework,
        target=args.target,
        platform=args.platform,
        cuda_version=args.cuda_version,
        make_efa=args.make_efa,
    )
    # Replace all instances of 3+ newlines with 2 newlines
    cleaned = re.sub(r"\n{3,}", "\n\n", rendered)

93
    if args.output_short_filename:
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
        filename = "rendered.Dockerfile"
    else:
        filename = f"{args.framework}-{args.target}-cuda{args.cuda_version}-{args.platform}-rendered.Dockerfile"

    with open(f"{script_dir}/{filename}", "w") as f:
        f.write(cleaned)

    if args.show_result:
        print("##############")
        print("# Dockerfile #")
        print("##############")
        print(cleaned)
        print("##############")

    print(f"INFO: Generated Dockerfile written to {script_dir}/{filename}")

    return


def main():
    args = parse_args()
    validate_args(args)
116
    script_dir = Path(__file__).parent
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
    with open(f"{script_dir}/context.yaml", "r") as f:
        context = yaml.safe_load(f)

    render(args, context, script_dir)

    if args.target == "local-dev":
        print(
            "INFO: Remember to add --build-arg values for USER_UID and USER_GID when building a local-dev image!"
        )
        print(
            "      Recommendation: --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -g)"
        )


if __name__ == "__main__":
    main()