# test_torch_tp.py
import unittest

3
from sglang.test.test_utils import is_in_ci, run_bench_one_batch
4
5
6
7


class TestTorchTP(unittest.TestCase):
    def test_torch_native_llama(self):
8
        output_throughput = run_bench_one_batch(
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
            "meta-llama/Meta-Llama-3-8B",
            [
                "--tp",
                "2",
                "--json-model-override-args",
                '{"architectures": ["TorchNativeLlamaForCausalLM"]}',
                "--disable-cuda-graph",
            ],
        )

        if is_in_ci():
            assert output_throughput > 0, f"{output_throughput=}"


# Allow running this file directly as a script; hands control to the
# unittest command-line runner, which discovers the tests in this module.
if __name__ == "__main__":
    unittest.main()