pin_rocm_dependencies.py 7.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Pin vLLM dependencies to exact versions of custom ROCm wheels.

This script modifies vLLM's requirements files to replace version constraints
with exact versions of custom-built ROCm wheels (torch, triton, torchvision, amdsmi).

This ensures that 'pip install vllm' automatically installs the correct custom wheels
instead of allowing pip to download different versions from PyPI.
"""

import sys
from pathlib import Path

17
18
import regex as re

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

def extract_version_from_wheel(wheel_name: str) -> str:
    """
    Extract the version field from a wheel filename.

    Wheel filenames have the form
    ``{distribution}-{version}(-{build tag})?-{python}-{abi}-{platform}.whl``,
    so the version is always the second dash-separated field.

    Example:
        torch-2.9.0a0+git1c57644-cp312-cp312-linux_x86_64.whl -> 2.9.0a0+git1c57644
        triton-3.4.0-cp312-cp312-linux_x86_64.whl -> 3.4.0

    Args:
        wheel_name: Wheel filename without any directory component.

    Returns:
        The version string exactly as it appears in the filename.

    Raises:
        ValueError: If the name has fewer than five dash-separated fields,
            or if the version field is empty.
    """
    # Strip only the trailing extension. str.replace() would also delete a
    # ".whl" occurring inside the name (e.g. a local version "1.0+build.whl").
    stem = wheel_name.removesuffix(".whl")
    parts = stem.split("-")

    if len(parts) < 5:
        raise ValueError(f"Invalid wheel filename format: {wheel_name}")

    # Version is the second part
    version = parts[1]
    if not version:
        # An empty version would later produce an unusable "pkg==" pin.
        raise ValueError(f"Invalid wheel filename format: {wheel_name}")
    return version


def get_custom_wheel_versions(install_dir: str) -> dict[str, str]:
    """
    Scan a directory of custom-built wheels and collect their versions.

    Only wheels whose filename starts with one of the known prefixes in
    ``package_mapping`` are considered; every other file is ignored.

    Args:
        install_dir: Directory containing the ``*.whl`` files.

    Returns:
        Dict mapping package names to exact versions, ordered as in
        ``package_mapping``. Empty if no known wheel was found.

    Exits the process with status 1 if ``install_dir`` is not a directory.
    """
    install_path = Path(install_dir)
    if not install_path.is_dir():
        print(f"ERROR: Install directory not found: {install_dir}", file=sys.stderr)
        sys.exit(1)

    # Map wheel prefixes to package names
    # IMPORTANT: Use dashes to avoid matching substrings
    #            (e.g., 'torch' would match 'torchvision')
    # ORDER MATTERS: This order is preserved when pinning dependencies
    #               in requirements files
    package_mapping = [
        ("torch-", "torch"),  # Match torch- (not torchvision)
        ("triton-", "triton"),  # Match triton- (not triton_kernels)
        ("triton_kernels-", "triton-kernels"),  # Match triton_kernels-
        ("torchvision-", "torchvision"),  # Match torchvision-
        ("torchaudio-", "torchaudio"),  # Match torchaudio-
        ("amdsmi-", "amdsmi"),  # Match amdsmi-
        ("flash_attn-", "flash-attn"),  # Match flash_attn-
        ("amd_aiter-", "amd-aiter"),  # Match amd_aiter-
    ]

    versions: dict[str, str] = {}

    # Sort so the result does not depend on filesystem enumeration order
    # when several wheels for the same package are present.
    for wheel_file in sorted(install_path.glob("*.whl")):
        wheel_name = wheel_file.name

        for prefix, package_name in package_mapping:
            if not wheel_name.startswith(prefix):
                continue

            try:
                version = extract_version_from_wheel(wheel_name)
            except ValueError as e:
                print(
                    f"WARNING: Could not extract version from {wheel_name}: {e}",
                    file=sys.stderr,
                )
            else:
                previous = versions.get(package_name)
                if previous is not None and previous != version:
                    print(
                        f"WARNING: Multiple wheels found for {package_name} "
                        f"({previous}, {version}); using {version}",
                        file=sys.stderr,
                    )
                versions[package_name] = version
                print(f"Found {package_name}=={version}", file=sys.stderr)
            # Prefixes are mutually exclusive, so stop at the first match.
            break

    # Return versions in the order defined by package_mapping
    return {
        package_name: versions[package_name]
        for _, package_name in package_mapping
        if package_name in versions
    }


# Comment lines written at the top of a patched requirements file. The first
# one also serves as the marker for detecting a header from a previous run.
_PIN_HEADER_LINES = (
    "# Custom ROCm wheel pins (auto-generated by pin_rocm_dependencies.py)\n",
    "# These must come FIRST to ensure correct dependency resolution\n",
)

# Leading distribution name of a requirement line. The lookahead requires the
# name to be followed by something that can legally follow it (end of line,
# whitespace, extras, a version operator, a marker, a URL or a comment).
_REQUIREMENT_NAME = re.compile(
    r"^([A-Za-z0-9](?:[A-Za-z0-9._-]*[A-Za-z0-9])?)(?=$|[\s\[(=<>!~;@#])"
)

# Shape of the "name==version" lines this script writes into its header.
_EXACT_PIN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*==\S+$")


def _normalize_package_name(name: str) -> str:
    """Lowercase *name* and collapse runs of '-', '_' and '.' into one '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def _strip_generated_header(lines: list[str]) -> list[str]:
    """Drop a header written by a previous run from the start of *lines*."""
    if not lines or lines[0] != _PIN_HEADER_LINES[0]:
        return lines

    idx = 0
    while idx < len(lines) and lines[idx] in _PIN_HEADER_LINES:
        idx += 1
    while idx < len(lines) and _EXACT_PIN.match(lines[idx].strip()):
        idx += 1
    # The header ends with a single blank separator line.
    if idx < len(lines) and not lines[idx].strip():
        idx += 1
    return lines[idx:]


def _pinned_package_for_line(stripped: str, pinned: dict[str, str]):
    """Return the pinned package a requirement line refers to, or None."""
    # Comments ("#..."), options ("-r ...") and blank lines never match
    # because a distribution name must start with a letter or digit.
    match = _REQUIREMENT_NAME.match(stripped)
    if match is None:
        return None
    return pinned.get(_normalize_package_name(match.group(1)))


def pin_dependencies_in_requirements(
    requirements_path: str, versions: dict[str, str]
) -> None:
    """
    Insert custom wheel pins at the TOP of requirements file.

    This ensures that when setup.py processes the file line-by-line,
    custom wheels (torch, triton, etc.) are encountered FIRST, before
    any `-r common.txt` includes that might pull in other dependencies.

    Any existing requirement line for a pinned package is removed, whatever
    its form (``torch>=2.0``, ``torch~=2.0``, ``torch[extra]==2.0``,
    ``torch @ <url>`` or a bare ``torch``). Names are compared
    case-insensitively with '-', '_' and '.' treated as equivalent. A header
    written by a previous run is replaced rather than duplicated.

    The current content is saved next to the file with a ``.bak`` suffix
    appended before the file is rewritten.

    Creates:
        # Custom ROCm wheel pins (auto-generated)
        torch==2.9.0a0+git1c57644
        triton==3.4.0
        torchvision==0.23.0a0+824e8c8
        amdsmi==26.1.0+5df6c765

        -r common.txt
        ... rest of file ...

    Args:
        requirements_path: Path of the requirements file to patch in place.
        versions: Mapping of package name to exact version; iteration order
            determines the order of the generated pins.

    Exits the process with status 1 if the requirements file does not exist.
    """
    requirements_file = Path(requirements_path)

    if not requirements_file.exists():
        print(
            f"ERROR: Requirements file not found: {requirements_path}", file=sys.stderr
        )
        sys.exit(1)

    # Backup original file
    backup_file = requirements_file.with_suffix(requirements_file.suffix + ".bak")
    with open(requirements_file, encoding="utf-8") as f:
        original_lines = f.readlines()

    with open(backup_file, "w", encoding="utf-8") as f:
        f.writelines(original_lines)

    # Build header with pinned custom wheels
    header_lines = list(_PIN_HEADER_LINES)
    header_lines.extend(
        f"{package_name}=={exact_version}\n"
        for package_name, exact_version in versions.items()
    )
    header_lines.append("\n")  # Blank line separator

    # Look packages up by normalized name so that e.g. "triton_kernels" in the
    # file matches the "triton-kernels" key in `versions`.
    pinned = {_normalize_package_name(name): name for name in versions}

    # Filter out any existing entries for custom packages from original file
    filtered_lines = []
    removed_packages = []

    for line in _strip_generated_header(original_lines):
        stripped = line.strip()
        package_name = _pinned_package_for_line(stripped, pinned)
        if package_name is None:
            filtered_lines.append(line)
        else:
            removed_packages.append(f"{package_name}: {stripped}")

    # Write modified content: header + filtered original content
    with open(requirements_file, "w", encoding="utf-8") as f:
        f.writelines(header_lines + filtered_lines)

    # Print summary
    print("\n✓ Inserted custom wheel pins at TOP of requirements:", file=sys.stderr)
    for package_name, exact_version in versions.items():
        print(f"  - {package_name}=={exact_version}", file=sys.stderr)

    if removed_packages:
        print("\n✓ Removed old package entries:", file=sys.stderr)
        for pkg in removed_packages:
            print(f"  - {pkg}", file=sys.stderr)

    print(f"\n✓ Patched requirements file: {requirements_path}", file=sys.stderr)
    print(f"  Backup saved: {backup_file}", file=sys.stderr)


def main():
    """
    Command-line entry point.

    Expects exactly two arguments: the directory containing the custom wheels
    and the requirements file to patch. Progress and errors are written to
    stderr. Exits with status 0 on success and 1 on a usage error or when no
    custom wheel is found.
    """
    if len(sys.argv) != 3:
        print(
            f"Usage: {sys.argv[0]} <install_dir> <requirements_file>", file=sys.stderr
        )
        print(
            f"Example: {sys.argv[0]} /install /app/vllm/requirements/rocm.txt",
            file=sys.stderr,
        )
        sys.exit(1)

    install_dir = sys.argv[1]
    requirements_path = sys.argv[2]

    print("=" * 70, file=sys.stderr)
    print("Pinning vLLM dependencies to custom ROCm wheel versions", file=sys.stderr)
    print("=" * 70, file=sys.stderr)

    # Get versions from custom wheels
    print(f"\nScanning {install_dir} for custom wheels...", file=sys.stderr)
    versions = get_custom_wheel_versions(install_dir)

    if not versions:
        # Report the directory that was actually scanned, not a fixed path.
        print(f"\nERROR: No custom wheels found in {install_dir}!", file=sys.stderr)
        sys.exit(1)

    # Pin dependencies in requirements file
    print(f"\nPatching {requirements_path}...", file=sys.stderr)
    pin_dependencies_in_requirements(requirements_path, versions)

    print("\n" + "=" * 70, file=sys.stderr)
    print("✓ Dependency pinning complete!", file=sys.stderr)
    print("=" * 70, file=sys.stderr)

    sys.exit(0)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()