# test_binding.py
import os
import re
import unittest

import sgl_kernel
import torch

from sglang.test.test_utils import CustomTestCase

# `import sgl_kernel` is kept even though the name is never used directly:
# importing it presumably registers the custom ops exposed under
# `torch.ops.sgl_kernel` (TODO confirm).
kernel = torch.ops.sgl_kernel


class TestGemm(CustomTestCase):
    """Check the thread/core binding reported by `init_cpu_threads_env`.

    NOTE(review): despite the class name, nothing here exercises a GEMM
    kernel; the name is kept so existing test selectors keep working.
    """

    def test_binding(self):
        """Request cores ``start_id .. start_id + n_cpu - 1`` and verify the report.

        The text returned by ``kernel.init_cpu_threads_env`` is parsed for
        ``OMP tid: <tid>, core <core>`` entries. Exactly ``n_cpu`` entries
        must be reported, and their cores must equal the requested ones in
        the same order.

        Raises:
            unittest.SkipTest: if any requested core is not available to this
                process (too few CPUs, or a restricted affinity mask).
        """
        start_id = 1
        n_cpu = 6

        requested = range(start_id, start_id + n_cpu)

        # Binding to a core outside the process's allowed CPU set presumably
        # cannot succeed, so skip on small machines or CPU-restricted
        # containers instead of reporting a spurious failure.
        if hasattr(os, "sched_getaffinity"):
            usable = os.sched_getaffinity(0)
        else:
            usable = set(range(os.cpu_count() or 0))
        missing = sorted(set(requested) - usable)
        if missing:
            raise unittest.SkipTest(
                f"CPU core(s) {missing} are not available to this process"
            )

        # The regex below captures core ids as strings, so compare against
        # string ids.
        expected_cores = [str(core) for core in requested]
        cpu_ids = ",".join(expected_cores)
        output = kernel.init_cpu_threads_env(cpu_ids)

        bindings = re.findall(r"OMP tid: \d+, core (\d+)", output)
        # Attach the raw kernel report to failures so they can be diagnosed.
        self.assertEqual(len(bindings), n_cpu, msg=output)
        self.assertEqual(bindings, expected_cores, msg=output)


if __name__ == "__main__":
    # Executed directly as a script: discover and run this module's tests.
    unittest.main(module="__main__")