ompmultiprocessing.py 5.73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""OMP Aware Multiprocessing manager for running multiprocessing.Process()
Copyright (c) 2026 Red Hat Inc
Copyright (c) 2026 Cambridge Greys Ltd
"""

import json
import os
import subprocess


def _int(arg):
    """Relaxed parsing of ints which handles a - instead of a number.
    The lscpu json may contain that for nodes in some cases. If that
    is the case we parse it to zero
    """
    try:
        if int(arg) >= 0:
            return int(arg)
    except ValueError:
        pass
    return 0


def parse_mask(mask):
    """Expand a comma separated "X-Y,Z" CPU list into a set of ints.

    Each comma separated item is either a single number or an inclusive
    ``start-finish`` range.

    Raises:
        IndexError: if a range has a start greater than its finish.
        ValueError: if an item is neither a number nor a valid range.
    """
    cpus = set()
    for item in mask.split(","):
        try:
            low, high = item.split("-")
            first, last = int(low), int(high)
        except ValueError:
            # Not a two-part numeric range: treat the item as one number.
            cpus.add(int(item))
            continue
        if first > last:
            raise IndexError("Invalid Indexes for cpu ranges")
        cpus.update(range(first, last + 1))
    return cpus


def enumerate_resources(resource_map, mask=None, allowed=None):
    """Index the usable CPUs of an lscpu table by cpu, core and node.

    Args:
        resource_map: Parsed lscpu JSON; its "cpus" entry is iterated and
            each record is read for "cpu", "core" and "node".
        mask: Optional set of CPU numbers intersected with *allowed*.
        allowed: Set of CPU numbers that may be used.  Defaults to the
            scheduler affinity of the current process.

    The CPU_VISIBLE_MEMORY_NODES environment variable, when set, is parsed
    as a CPU-list style mask and restricts the result to those nodes.

    Returns:
        A dict with "cpus", "cores" and "nodes" tables, each mapping an
        integer id to the list of lscpu records that belong to it.
    """
    usable = os.sched_getaffinity(0) if allowed is None else allowed
    if mask is not None:
        usable = usable & mask

    node_spec = os.environ.get("CPU_VISIBLE_MEMORY_NODES")
    visible_nodes = None if node_spec is None else parse_mask(node_spec)

    topology: dict[str, dict] = {"cpus": {}, "cores": {}, "nodes": {}}
    for entry in resource_map["cpus"]:
        cpu_id = int(entry["cpu"])
        if cpu_id not in usable or cpu_id < 0:
            continue
        if visible_nodes is not None and _int(entry["node"]) not in visible_nodes:
            continue
        topology["cpus"][cpu_id] = [entry]
        # Group the same record under its core and its node.
        topology["cores"].setdefault(_int(entry["core"]), []).append(entry)
        topology["nodes"].setdefault(_int(entry["node"]), []).append(entry)
    return topology


def produce_cpu_list(cpus, smt=1):
    """Produce a CPU list with/without SMT siblings - main cpu list case.

    Keeps at most *smt* CPUs from each physical core, in the iteration
    order of *cpus*.

    Args:
        cpus: Mapping of CPU number to a list whose first element is the
            lscpu record for that CPU (a dict with a "core" entry), as in
            the "cpus" table built by ``enumerate_resources``.
        smt: Maximum number of CPUs (hardware threads) kept per core.

    Returns:
        A place descriptor ``{"mask": <set of CPU numbers>,
        "available": True}``.
    """
    chosen: list[int] = []
    taken_per_core: dict[str, int] = {}
    for cpu_id, records in cpus.items():
        # Count per *core id*, not per CPU number.  The text form is used
        # as the key so numeric and quoted lscpu JSON values compare equal
        # and a "-" placeholder does not raise.
        core = str(records[0]["core"])
        taken = taken_per_core.get(core, 0)
        if taken < smt:
            taken_per_core[core] = taken + 1
            chosen.append(int(cpu_id))
    return {"mask": set(chosen), "available": True}


def produce_cpu_sublist(scpus, smt=1):
    """Produce a CPU list with/without SMT siblings - resource leaf case.

    Keeps at most *smt* CPUs from each physical core, in the order the
    records appear in *scpus*.

    Args:
        scpus: Iterable of lscpu records (dicts with "cpu" and "core"
            entries) belonging to one resource such as a core or a node.
        smt: Maximum number of CPUs (hardware threads) kept per core.

    Returns:
        A place descriptor ``{"mask": <set of CPU numbers>,
        "available": True}``.
    """
    kept: list[int] = []
    taken_per_core: dict[str, int] = {}
    for record in scpus:
        # A real per-core count lets smt=2 cap SMT-4/8 cores at two
        # threads; stopping at the first match would admit all of them.
        # The text form is the key so numeric and quoted ids compare equal.
        core = str(record["core"])
        taken = taken_per_core.get(core, 0)
        if taken < smt:
            taken_per_core[core] = taken + 1
            kept.append(int(record["cpu"]))
    return {"mask": set(kept), "available": True}


def create_omp_places(resources, strategy, smt=True):
    """Build the list of candidate CPU masks for a placement strategy.

    Args:
        resources: Table with "cpus", "cores" and "nodes" entries, as
            returned by ``enumerate_resources``.
        strategy: "all" for one place over every CPU, "cores" for one
            place per core, "nodes" for one place per node.
        smt: Passed through to the CPU list builders.

    Returns:
        A list of place descriptors.

    Raises:
        NotImplementedError: if *strategy* is not one of the above.
    """
    if strategy == "all":
        return [produce_cpu_list(resources["cpus"], smt)]
    if strategy in ("cores", "nodes"):
        # Both strategies emit one place per entry of the matching table.
        return [
            produce_cpu_sublist(group, smt)
            for group in resources[strategy].values()
        ]
    raise NotImplementedError("Unknown strategy")


# pylint: disable=too-few-public-methods
class OMPProcessManager:
    """OMP aware wrapper to run mp Process().

    On construction the CPU topology is read (from ``lscpu -Je`` or a mock
    file) and turned into a list of place descriptors, each a dict with a
    "mask" (set of CPU numbers) and an "available" flag.  Every call to
    ``run`` consumes the next available place and exports it through the
    OMP_PLACES, OMP_NUM_THREADS and OMP_PROC_BIND environment variables
    before invoking the callable.

    Environment variables read:
        VLLM_CPU_OMP_THREADS_BIND: "nobind" disables all OMP setup; unset
            or "auto" uses every allowed CPU; anything else is a "|"
            separated list of CPU masks, one per group of places.
        VLLM_CPU_NUM_OF_RESERVED_CPU: number of CPUs subtracted from the
            thread count of a place (read in ``run``).
    """

    def __init__(self, strategy="nodes", smt=1, mock=None, affinity=None):
        """Discover the topology and prepare the list of OMP places.

        Args:
            strategy: Placement strategy handed to ``create_omp_places``.
            smt: SMT setting handed to ``create_omp_places``.
            mock: Optional path of a file holding lscpu JSON, read instead
                of running ``lscpu``.
            affinity: Optional set of allowed CPU numbers handed to
                ``enumerate_resources``.
        """
        self.strategy = strategy
        self.smt = smt
        self.omp_places = []
        vllm_mask = os.environ.get("VLLM_CPU_OMP_THREADS_BIND", None)
        self.setup_omp = vllm_mask != "nobind"
        if not self.setup_omp:
            return

        masks = self._requested_masks(vllm_mask)
        if mock is None:
            data = subprocess.run(
                ["lscpu", "-Je"], check=True, capture_output=True
            ).stdout
        else:
            with open(mock, mode="rb") as jf:
                data = jf.read()
        lscpu = json.loads(data)

        omp_places = []
        for mask in masks:
            resources = enumerate_resources(lscpu, mask, affinity)
            omp_places.extend(create_omp_places(resources, strategy, smt))
        # Largest places first, ties broken by the highest CPU number.  A
        # numeric key keeps the order right for values of 10000 and above,
        # where a zero-padded four digit string key would not.
        self.omp_places = sorted(
            omp_places,
            key=lambda p: (len(p["mask"]), max(p["mask"])),
            reverse=True,
        )

    @staticmethod
    def _requested_masks(spec):
        """Translate the VLLM_CPU_OMP_THREADS_BIND value into CPU masks.

        Returns ``[None]`` (no extra restriction) when the variable is
        unset or explicitly "auto"; otherwise one parsed mask per "|"
        separated part.
        """
        if spec is None or spec == "auto":
            return [None]
        return [parse_mask(part) for part in spec.split("|")]

    def run(self, what, *args, **kwargs):
        """Run *what* with the OMP environment of the next free place.

        The first place still marked available is marked as used and
        exported via ``os.environ`` of the current process, then
        ``what(*args, **kwargs)`` is called and its result returned.
        Places are never handed back by this method.  When OMP setup is
        disabled ("nobind") the callable is invoked without touching the
        environment.

        Raises:
            IndexError: if OMP setup is enabled and no place is available.
        """
        if not self.setup_omp:
            return what(*args, **kwargs)

        place = next((p for p in self.omp_places if p["available"]), None)
        if place is None:
            raise IndexError("Out of OMP places")

        reserve = int(os.environ.get("VLLM_CPU_NUM_OF_RESERVED_CPU", 0))
        place["available"] = False
        # The text form of the set, e.g. "{0, 1, 2}", is exported as is.
        os.environ["OMP_PLACES"] = str(place["mask"])
        # Never export a zero or negative thread count, which would happen
        # when the reservation is at least as large as the place.
        os.environ["OMP_NUM_THREADS"] = str(max(1, len(place["mask"]) - reserve))
        os.environ["OMP_PROC_BIND"] = "TRUE"
        return what(*args, **kwargs)