test_hybrid_bloom.py 3.95 KB
Newer Older
1
2
import importlib.util

3
4
5
6
import pytest
import torch
import torch.distributed as dist
import transformers
7
from packaging import version
8
9

import colossalai
10
from colossalai.inference import InferenceEngine
11
12
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

13
# torch.version.cuda is None on PyTorch builds without CUDA; check for that first so that
# importing this module yields CUDA_SUPPORT = False (and the test is skipped) instead of
# raising inside version.parse.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.5")
14
15
16
17
# True when the optional `lightllm` package can be located by the import system.
HAS_LIGHTLLM_KERNEL = importlib.util.find_spec("lightllm") is not None
18

19
20
21
22
23
24
25
26
27

def data_gen():
    """Build a one-sample batch of token ids with an all-ones attention mask.

    Returns:
        dict: ``input_ids`` and ``attention_mask``, both int64 tensors of shape (1, 8).
    """
    token_ids = [15496, 11, 616, 3290, 318, 13779, 318, 13779]
    batch = {
        "input_ids": torch.tensor([token_ids], dtype=torch.int64),
        # Every position is a real token, so the mask is all ones.
        "attention_mask": torch.ones((1, len(token_ids)), dtype=torch.int64),
    }
    return batch


# Shared module-level batch: every tensor from data_gen() is moved to the GPU and
# tiled 16 times along its first dimension (other dimensions are left unchanged).
inputs = data_gen()
for name, tensor in inputs.items():
    if not (torch.is_tensor(tensor) or "Tensor" in tensor.__class__.__name__):
        continue
    repeats = [1] * tensor.dim()
    repeats[0] = 16
    inputs[name] = tensor.to("cuda").repeat(*repeats)
32

33

34
def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Generate with a small Bloom model and check the length of the first output.

    The model is built directly from a ``BloomConfig`` (no checkpoint is loaded) and
    wrapped in an ``InferenceEngine``; generation runs on the module-level ``inputs``.

    Args:
        tp_size: forwarded to ``InferenceEngine`` as ``tp_size``.
        pp_size: forwarded to ``InferenceEngine`` as ``pp_size``.
        max_output_len: forwarded to ``InferenceEngine``; also the expected length
            of ``output[0]``.
        micro_batch_size: forwarded to ``InferenceEngine`` as ``micro_batch_size``.

    Raises:
        AssertionError: on rank 0, if ``len(output[0])`` differs from ``max_output_len``.
    """
    model = transformers.BloomForCausalLM(
        transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
    )

    engine = InferenceEngine(
        tp_size=tp_size,
        pp_size=pp_size,
        model=model,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    output = engine.generate(inputs)
    # Only rank 0 validates the result.
    if dist.get_rank() == 0:
        # The message reports the same quantity that is compared (length of the first
        # output), so a failure shows the actual mismatching value.
        assert len(output[0]) == max_output_len, f"{len(output[0])}, {max_output_len}"
49
50


51
@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Pipeline-parallel-only case (tp_size=1, pp_size=2): run the shared check, then empty the CUDA cache."""
    pipeline_inference_test(
        tp_size=tp_size,
        pp_size=pp_size,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    torch.cuda.empty_cache()


@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Combined case (tp_size=2, pp_size=2): run the shared check, then empty the CUDA cache."""
    pipeline_inference_test(
        tp_size=tp_size,
        pp_size=pp_size,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    torch.cuda.empty_cache()


71
72
73
74
75
76
77
78
@parameterize("tp_size", [2])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Tensor-parallel-only case (tp_size=2, pp_size=1): run the shared check, then empty the CUDA cache."""
    pipeline_inference_test(
        tp_size=tp_size,
        pp_size=pp_size,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    torch.cuda.empty_cache()
79
80


81
82
83
84
85
86
87
88
89
90
91
@parameterize("tp_size", [1])
@parameterize("pp_size", [1])
@parameterize("max_output_len", [2])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Single-device case (tp_size=1, pp_size=1): run the shared check, then empty the CUDA cache."""
    pipeline_inference_test(
        tp_size=tp_size,
        pp_size=pp_size,
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    torch.cuda.empty_cache()


def check_tp_pp_inference(rank, world_size, port):
    """Worker entry point: launch ColossalAI on localhost with NCCL, then run the tp+pp case.

    Args:
        rank: rank passed to ``colossalai.launch``.
        world_size: world size passed to ``colossalai.launch``.
        port: port passed to ``colossalai.launch``.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    colossalai.launch(**launch_kwargs)
    run_tp_pipeline_inference_test()


96
def check_tp_or_pp_inference(rank, world_size, port):
    """Worker entry point: launch ColossalAI on localhost with NCCL, then run the tp-only and pp-only cases.

    Args:
        rank: rank passed to ``colossalai.launch``.
        world_size: world size passed to ``colossalai.launch``.
        port: port passed to ``colossalai.launch``.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    colossalai.launch(**launch_kwargs)
    # Tensor-parallel case first, then the pipeline-parallel case.
    run_tp_inference_test()
    run_pipeline_inference_test()


102
103
104
105
106
def check_single_inference(rank, world_size, port):
    """Worker entry point: launch ColossalAI on localhost with NCCL, then run the single-device case.

    Args:
        rank: rank passed to ``colossalai.launch``.
        world_size: world size passed to ``colossalai.launch``.
        port: port passed to ``colossalai.launch``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    # The function must actually be called; a bare name reference is a no-op expression
    # and would leave the single-device case untested.
    run_single_inference_test()


107
108
109
110
@pytest.mark.skipif(
    not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
    # The reason names both conditions so a skip caused by a missing lightllm
    # package is not misreported as a CUDA version problem.
    reason="kv-cache manager engine requires cuda version to be higher than 11.5 and lightllm to be installed",
)
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
    """Spawn worker processes for each parallel layout exercised by this module.

    Runs the tp+pp check on 4 processes, the tp-only / pp-only checks on 2 processes,
    and the single-device check on 1 process.
    """
    spawn(check_tp_pp_inference, nprocs=4)
    spawn(check_tp_or_pp_inference, nprocs=2)
    spawn(check_single_inference, nprocs=1)
118
119


120
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_pipeline_inference()