# test_pipeline_parallel.py
1
import pytest
2
from transformers import AutoTokenizer
3

4
from ..utils import RemoteOpenAIServer
5
6


7
@pytest.mark.parametrize(
    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
    ])
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
    """Check that a pipeline-parallel run matches a tensor-parallel run.

    The same sequence of requests (model listing, text prompt, token-id
    prompt, batched prompts, streamed batched prompts) is issued once
    against a ``RemoteOpenAIServer`` created with the pipeline-parallel
    arguments and once against one created with the tensor-parallel-only
    arguments. The recorded responses must be equal pairwise.

    Args:
        TP_SIZE: value passed as ``--tensor-parallel-size`` for the
            pipeline-parallel configuration.
        PP_SIZE: value passed as ``--pipeline-parallel-size``.
        EAGER_MODE: when truthy, ``--enforce-eager`` is added to both
            configurations.
        CHUNKED_PREFILL: when truthy, ``--enable-chunked-prefill`` is added
            to both configurations.
        MODEL_NAME: model identifier used for the tokenizer, the server and
            every request.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    pp_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--pipeline-parallel-size",
        str(PP_SIZE),
        "--tensor-parallel-size",
        str(TP_SIZE),
        "--distributed-executor-backend",
        "ray",
    ]

    # compare without pipeline parallelism
    # NOTE: use mp backend for TP
    # PP tests might involve multiple nodes, and ray might
    #  schedule all workers in a node other than the head node,
    #  which can cause the test to fail.
    tp_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--tensor-parallel-size",
        # NOTE(review): for the parametrized TP_SIZE values (1 or 2) this
        # always evaluates to 2, i.e. the baseline always runs with TP=2.
        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
        "--distributed-executor-backend",
        "mp",
    ]
    if CHUNKED_PREFILL:
        pp_args.append("--enable-chunked-prefill")
        tp_args.append("--enable-chunked-prefill")
    if EAGER_MODE:
        pp_args.append("--enforce-eager")
        tp_args.append("--enforce-eager")

    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt)["input_ids"]

    # One list of recorded responses per configuration, in the order
    # (pipeline-parallel, tensor-parallel). Keeping them separate avoids
    # having to split a flat list in half afterwards.
    results_per_config = []
    for args in (pp_args, tp_args):
        results = []
        with RemoteOpenAIServer(MODEL_NAME, args) as server:
            client = server.get_client()

            # test models list
            models = client.models.list()
            models = models.data
            served_model = models[0]
            results.append({
                "test": "models_list",
                "id": served_model.id,
                "root": served_model.root,
            })

            # test with text prompt
            completion = client.completions.create(model=MODEL_NAME,
                                                   prompt=prompt,
                                                   max_tokens=5,
                                                   temperature=0.0)

            results.append({
                "test": "single_completion",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test using token IDs
            completion = client.completions.create(
                model=MODEL_NAME,
                prompt=token_ids,
                max_tokens=5,
                temperature=0.0,
            )

            results.append({
                "test": "token_ids",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test simple list
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=[prompt, prompt],
                max_tokens=5,
                temperature=0.0,
            )

            results.append({
                "test": "simple_list",
                "text0": batch.choices[0].text,
                "text1": batch.choices[1].text,
            })

            # test streaming
            batch = client.completions.create(
                model=MODEL_NAME,
                prompt=[prompt, prompt],
                max_tokens=5,
                temperature=0.0,
                stream=True,
            )
            # Two prompts were sent, so streamed chunks are accumulated
            # into two slots keyed by the choice index.
            texts = [""] * 2
            for chunk in batch:
                assert len(chunk.choices) == 1
                choice = chunk.choices[0]
                texts[choice.index] += choice.text
            results.append({
                "test": "streaming",
                "texts": texts,
            })
        results_per_config.append(results)

    pp_results, tp_results = results_per_config
    assert len(pp_results) == len(tp_results)
    for pp, tp in zip(pp_results, tp_results):
        # Name the failing sub-test so a mismatch is easy to locate.
        assert pp == tp, (
            f"{pp['test']}: pipeline-parallel result {pp!r} differs from "
            f"tensor-parallel result {tp!r}")