render.py 5.59 KB
Newer Older
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import re
from pathlib import Path

import yaml
10
from jinja2 import Environment, FileSystemLoader, StrictUndefined
11
12
13
14
15
16
17
18
19
20


def parse_args(argv=None):
    """Parse the command-line options that select which Dockerfile to render.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what callers passing no argument get.

    Returns:
        argparse.Namespace with the attributes ``framework``, ``device``,
        ``target``, ``platform``, ``cuda_version``, ``make_efa``,
        ``output_short_filename`` and ``show_result``.
    """
    parser = argparse.ArgumentParser(
        description="Renders dynamo Dockerfiles from templates"
    )
    parser.add_argument(
        "--framework",
        type=str,
        default="vllm",
        choices=["dynamo", "vllm", "sglang", "trtllm"],
        help="Dockerfile framework to use",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "xpu"],
        help="Dockerfile device to use",
    )
    # --target and --platform deliberately carry no `choices`: the help text
    # only lists examples, and any string is accepted by the parser.
    parser.add_argument(
        "--target",
        type=str,
        default="runtime",
        help="Dockerfile target to use. Non-exhaustive examples: [runtime, dev, local-dev]",
    )
    parser.add_argument(
        "--platform",
        type=str,
        default="amd64",
        help="Dockerfile platform to use. [amd64, arm64]",
    )
    parser.add_argument(
        "--cuda-version",
        type=str,
        default="12.9",
        choices=["12.9", "13.0", "13.1"],
        help="CUDA version to use. [12.9 or 13.0 for vllm and sglang, 13.1 for trtllm]",
    )
    parser.add_argument("--make-efa", action="store_true", help="Enable AWS EFA")
    parser.add_argument(
        "--output-short-filename",
        action="store_true",
        # The long name mirrors the pattern used when the file is written:
        # the device name (not a literal "cuda") followed by the CUDA version.
        help="Output filename is rendered.Dockerfile instead of <framework>-<target>-<device><cuda_version>-<platform>-rendered.Dockerfile",
    )
    parser.add_argument(
        "--show-result",
        action="store_true",
        help="Prints the rendered Dockerfile to stdout.",
    )
    return parser.parse_args(argv)


def validate_args(args):
    """Reject framework/target/CUDA-version/device combinations that are not listed.

    Args:
        args: Parsed arguments exposing ``framework``, ``target``,
            ``cuda_version`` and ``device`` attributes.

    Returns:
        None when the combination appears in the support table below.

    Raises:
        ValueError: If the framework is not in the table, or if the target,
            CUDA version or device is not listed for that framework.
    """
    # Allow-lists per framework; every listed option must match for the
    # combination to be accepted.
    supported = {
        "vllm": {
            "target": (
                "runtime",
                "dev",
                "local-dev",
                "framework",
                "wheel_builder",
                "base",
            ),
            "cuda_version": ("12.9", "13.0"),
            "device": ("cuda", "xpu"),
        },
        "trtllm": {
            "target": (
                "runtime",
                "dev",
                "local-dev",
                "framework",
                "wheel_builder",
                "base",
            ),
            "cuda_version": ("13.1",),
            "device": ("cuda",),
        },
        "sglang": {
            "target": (
                "runtime",
                "dev",
                "local-dev",
                "wheel_builder",
                "base",
            ),
            "cuda_version": ("12.9", "13.0"),
            "device": ("cuda",),
        },
        "dynamo": {
            "target": (
                "runtime",
                "dev",
                "local-dev",
                "frontend",
                "wheel_builder",
                "base",
            ),
            "cuda_version": ("12.9", "13.0"),
            "device": ("cuda",),
        },
    }

    allowed = supported.get(args.framework)
    if allowed is not None and all(
        getattr(args, option) in values for option, values in allowed.items()
    ):
        return

    # Unknown framework and unsupported option share the same error message.
    raise ValueError(
        f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version},device={args.device}]"
    )
133
134
135
136


def render(args, context, script_dir):
    """Render ``Dockerfile.template`` and write the result into ``script_dir``.

    Args:
        args: Parsed arguments. Reads ``framework``, ``device``, ``target``,
            ``platform``, ``cuda_version`` and ``make_efa`` (all passed to the
            template), plus ``output_short_filename`` and ``show_result``.
        context: Value handed to the template under the name ``context``.
        script_dir: Directory searched for ``Dockerfile.template`` and used as
            the output directory.

    Returns:
        None. The rendered text is written to disk and optionally printed.
    """
    env = Environment(
        loader=FileSystemLoader(script_dir),
        trim_blocks=False,
        lstrip_blocks=True,
        undefined=StrictUndefined,  # Raise an error if a variable in the template is not provided in the context
    )
    template = env.get_template("Dockerfile.template")
    rendered = template.render(
        context=context,
        framework=args.framework,
        device=args.device,
        target=args.target,
        platform=args.platform,
        cuda_version=args.cuda_version,
        make_efa=args.make_efa,
    )
    # Replace all instances of 3+ newlines with 2 newlines
    cleaned = re.sub(r"\n{3,}", "\n\n", rendered)

    if args.output_short_filename:
        filename = "rendered.Dockerfile"
    else:
        filename = f"{args.framework}-{args.target}-{args.device}{args.cuda_version}-{args.platform}-rendered.Dockerfile"

    output_path = f"{script_dir}/{filename}"

    # Pin the encoding so the output does not depend on the machine's locale;
    # without it, non-ASCII template content can fail to encode on write.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(cleaned)

    if args.show_result:
        print("##############")
        print("# Dockerfile #")
        print("##############")
        print(cleaned)
        print("##############")

    print(f"INFO: Generated Dockerfile written to {output_path}")


def main():
    """Script entry point: parse and validate arguments, then render the Dockerfile.

    Loads ``context.yaml`` from the directory containing this file and passes
    it, together with the parsed arguments, to ``render``. Validation errors
    from ``validate_args`` propagate to the caller.
    """
    args = parse_args()
    validate_args(args)

    # Clear cuda version for non-cuda device
    # (done after validation, so the check above still sees the user's value).
    if args.device != "cuda":
        args.cuda_version = ""

    script_dir = Path(__file__).parent

    # Explicit encoding keeps the YAML decoding independent of the locale.
    with open(f"{script_dir}/context.yaml", "r", encoding="utf-8") as f:
        context = yaml.safe_load(f)

    render(args, context, script_dir)

    if args.target == "local-dev":
        print(
            "INFO: Remember to add --build-arg values for USER_UID and USER_GID when building a local-dev image!"
        )
        print(
            "      Recommendation: --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -g)"
        )


# Run the renderer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()