# test_pipeline_parallel.py
import pytest

3
from ..utils import RemoteOpenAIServer
4
5


6
@pytest.mark.parametrize(
    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME",
    [
        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
        # TODO: figure out why PP=4 tests are flaky
        # (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        # (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
    ])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
    """Check that a pipeline-parallel server matches a tensor-parallel one.

    Two servers are started one after the other through
    ``RemoteOpenAIServer``: the first with ``--pipeline-parallel-size`` set
    to ``PP_SIZE`` (ray backend), the second without that flag (mp backend).
    The same sequence of requests is sent to each, and the recorded
    responses must be equal record by record.

    Args:
        TP_SIZE: Value passed to ``--tensor-parallel-size`` for the PP run.
            The reference run uses ``max(TP_SIZE, 2)``.
        PP_SIZE: Value passed to ``--pipeline-parallel-size`` for the PP run.
        EAGER_MODE: When truthy, ``--enforce-eager`` is added to both runs.
        CHUNKED_PREFILL: When truthy, ``--enable-chunked-prefill`` is added
            to both runs.
        MODEL_NAME: Model passed to ``--model`` and used in every request.
    """
    pp_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--pipeline-parallel-size",
        str(PP_SIZE),
        "--tensor-parallel-size",
        str(TP_SIZE),
        "--distributed-executor-backend",
        "ray",
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--tensor-parallel-size",
        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
        "--distributed-executor-backend",
        "mp",
    ]
    if CHUNKED_PREFILL:
        pp_args.append("--enable-chunked-prefill")
        tp_args.append("--enable-chunked-prefill")
    if EAGER_MODE:
        pp_args.append("--enforce-eager")
        tp_args.append("--enforce-eager")

    # One list of result records per server run, in the order [PP, TP].
    # Keeping the runs separate avoids splitting a flat list by position.
    results_per_run = []
    for args in [pp_args, tp_args]:
        results = []
        with RemoteOpenAIServer(args) as server:
            client = server.get_client()

            # test models list
            models = client.models.list()
            models = models.data
            served_model = models[0]
            results.append({
                "test": "models_list",
                "id": served_model.id,
                "root": served_model.root,
            })

            # test with text prompt
            completion = client.completions.create(model=MODEL_NAME,
                                                   prompt="Hello, my name is",
                                                   max_tokens=5,
                                                   temperature=0.0)

            results.append({
                "test": "single_completion",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test using token IDs
            completion = client.completions.create(
                model=MODEL_NAME,
                prompt=[0, 0, 0, 0, 0],
                max_tokens=5,
                temperature=0.0,
            )

            results.append({
                "test": "token_ids",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test simple list
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=["Hello, my name is", "Hello, my name is"],
                max_tokens=5,
                temperature=0.0,
            )

            results.append({
                "test": "simple_list",
                "text0": batch.choices[0].text,
                "text1": batch.choices[1].text,
            })

            # test streaming
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=["Hello, my name is", "Hello, my name is"],
                max_tokens=5,
                temperature=0.0,
                stream=True,
            )
            # Stream chunks for the two prompts arrive interleaved; rebuild
            # each prompt's text using the choice index carried by the chunk.
            texts = [""] * 2
            for chunk in batch:
                assert len(chunk.choices) == 1
                choice = chunk.choices[0]
                texts[choice.index] += choice.text
            results.append({
                "test": "streaming",
                "texts": texts,
            })
        results_per_run.append(results)

    pp_results, tp_results = results_per_run
    # zip() would silently drop trailing records if the runs ever diverged
    # in length, so check that explicitly first.
    assert len(pp_results) == len(tp_results)
    for pp, tp in zip(pp_results, tp_results):
        assert pp == tp, (
            f"PP and TP servers disagree on {pp['test']!r}: "
            f"pp={pp!r}, tp={tp!r}")