"""Tensor-parallel (``--tp 2``) smoke test for the offline throughput benchmark."""

import unittest

from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    run_bench_offline_throughput,
)


class TestTorchTP(CustomTestCase):
    """Smoke test running the offline throughput benchmark with tensor parallelism 2."""

    def test_torch_native_llama(self):
        """Benchmark the small test model with ``--tp 2`` and CUDA graphs disabled.

        NOTE: despite its name, this test no longer passes the
        ``TorchNativeLlamaForCausalLM`` architecture override. That override was
        disabled because it stopped running with a newer torch version, so the
        benchmark now runs without any architecture override.
        """
        output_throughput = run_bench_offline_throughput(
            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
            [
                "--tp",
                "2",
                # NOTE: The following architecture override cannot run anymore
                # with the new torch version, so it is intentionally omitted:
                #   "--json-model-override-args",
                #   '{"architectures": ["TorchNativeLlamaForCausalLM"]}'
                "--disable-cuda-graph",
            ],
        )

        # The throughput check is only enforced in CI; outside CI the test
        # passes as long as the benchmark completes without raising.
        if is_in_ci():
            self.assertGreater(output_throughput, 0)


if __name__ == "__main__":
    # Allow running this test module directly (python test_torch_tp.py).
    unittest.main()