render.py 5.04 KB
Newer Older
1
2
3
4
5
6
7
8
9
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import re
from pathlib import Path

import yaml
10
from jinja2 import Environment, FileSystemLoader, StrictUndefined
11
12
13
14
15
16
17
18
19
20


def parse_args():
    """Parse the command-line options for the Dockerfile renderer.

    Returns:
        argparse.Namespace with the attributes ``framework``, ``target``,
        ``platform``, ``cuda_version``, ``make_efa``,
        ``output_short_filename`` and ``show_result``.

    Only ``--framework`` and ``--cuda-version`` are restricted to a fixed set
    of values here; whether the framework/target/CUDA combination is
    supported is checked separately by ``validate_args``.
    """
    parser = argparse.ArgumentParser(
        description="Renders dynamo Dockerfiles from templates"
    )
    parser.add_argument(
        "--framework",
        type=str,
        default="vllm",
        choices=["dynamo", "vllm", "sglang", "trtllm"],
        help="Dockerfile framework to use",
    )
    parser.add_argument(
        "--target",
        type=str,
        default="runtime",
        help="Dockerfile target to use. Non-exhaustive examples: [runtime, dev, local-dev]",
    )
    # NOTE(review): --platform has no `choices`, so any string is passed
    # through to the template and the output filename unchecked.
    parser.add_argument(
        "--platform",
        type=str,
        default="amd64",
        help="Dockerfile platform to use. [amd64, arm64]",
    )
    parser.add_argument(
        "--cuda-version",
        type=str,
        default="12.9",
        choices=["12.9", "13.0", "13.1"],
        help="CUDA version to use. [12.9 or 13.0 for dynamo, vllm and sglang; 13.1 for trtllm]",
    )
    parser.add_argument("--make-efa", action="store_true", help="Enable AWS EFA")
    parser.add_argument(
        "--output-short-filename",
        action="store_true",
        help="Output filename is rendered.Dockerfile instead of <framework>-<target>-cuda<cuda_version>-<arch>-rendered.Dockerfile",
    )
    parser.add_argument(
        "--show-result",
        action="store_true",
        help="Prints the rendered Dockerfile to stdout.",
    )
    args = parser.parse_args()
    return args


def validate_args(args):
    """Check that the framework/target/CUDA-version combination is supported.

    Args:
        args: Parsed arguments; only ``framework``, ``target`` and
            ``cuda_version`` are read.

    Returns:
        None when the combination is supported.

    Raises:
        ValueError: if the framework is unknown, or if the target or CUDA
            version is not supported for that framework.
    """
    valid_inputs = {
        "vllm": {
            "target": [
                "runtime",
                "dev",
                "local-dev",
                "framework",
                "wheel_builder",
                "base",
            ],
            "cuda_version": ["12.9", "13.0"],
        },
        "trtllm": {
            "target": [
                "runtime",
                "dev",
                "local-dev",
                "framework",
                "wheel_builder",
                "base",
            ],
            "cuda_version": ["13.1"],
        },
        "sglang": {
            "target": [
                "runtime",
                "dev",
                "local-dev",
                "wheel_builder",
                "base",
            ],
            "cuda_version": ["12.9", "13.0"],
        },
        "dynamo": {
            "target": [
                "runtime",
                "dev",
                "local-dev",
                "frontend",
                "wheel_builder",
                "base",
            ],
            "cuda_version": ["12.9", "13.0"],
        },
    }

    allowed = valid_inputs.get(args.framework)
    if (
        allowed is not None
        and args.target in allowed["target"]
        and args.cuda_version in allowed["cuda_version"]
    ):
        return

    # Unknown frameworks and unsupported target/CUDA combinations share one error.
    raise ValueError(
        f"Invalid input combination: [framework={args.framework},target={args.target},cuda_version={args.cuda_version}]"
    )
118
119
120
121


def render(args, context, script_dir):
    """Render ``Dockerfile.template`` and write the result into ``script_dir``.

    Args:
        args: Parsed arguments; ``framework``, ``target``, ``platform``,
            ``cuda_version``, ``make_efa``, ``output_short_filename`` and
            ``show_result`` are read.
        context: Value loaded from ``context.yaml``; passed to the template
            as the ``context`` variable.
        script_dir: Directory that holds ``Dockerfile.template``; the rendered
            Dockerfile is written to the same directory. May be a ``str`` or
            a ``pathlib.Path``.

    Raises:
        jinja2 errors if the template is missing or references a variable
        that is not supplied (``StrictUndefined``).
    """
    env = Environment(
        loader=FileSystemLoader(script_dir),
        trim_blocks=False,
        lstrip_blocks=True,
        undefined=StrictUndefined,  # Raise an error if a variable in the template is not provided in the context
    )
    template = env.get_template("Dockerfile.template")
    rendered = template.render(
        context=context,
        framework=args.framework,
        target=args.target,
        platform=args.platform,
        cuda_version=args.cuda_version,
        make_efa=args.make_efa,
    )
    # Collapse every run of 3+ newlines into a single blank line.
    cleaned = re.sub(r"\n{3,}", "\n\n", rendered)

    if args.output_short_filename:
        filename = "rendered.Dockerfile"
    else:
        filename = f"{args.framework}-{args.target}-cuda{args.cuda_version}-{args.platform}-rendered.Dockerfile"

    output_path = Path(script_dir) / filename
    # Write UTF-8 explicitly so the output does not depend on the host locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(cleaned)

    if args.show_result:
        print("##############")
        print("# Dockerfile #")
        print("##############")
        print(cleaned)
        print("##############")

    print(f"INFO: Generated Dockerfile written to {output_path}")


def main():
    """Script entry point: parse and validate arguments, load ``context.yaml``, render.

    Raises:
        ValueError: propagated from ``validate_args`` for unsupported
            framework/target/CUDA combinations.
    """
    args = parse_args()
    validate_args(args)
    # context.yaml, Dockerfile.template and the rendered output all live
    # in the directory containing this script.
    script_dir = Path(__file__).parent
    with open(script_dir / "context.yaml", "r", encoding="utf-8") as f:
        context = yaml.safe_load(f)

    render(args, context, script_dir)

    if args.target == "local-dev":
        print(
            "INFO: Remember to add --build-arg values for USER_UID and USER_GID when building a local-dev image!"
        )
        print(
            "      Recommendation: --build-arg USER_UID=$(id -u) --build-arg USER_GID=$(id -g)"
        )


# Run the renderer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()