# test_pipeline_infer.py
import pytest
import torch
import torch.distributed as dist
import transformers
5
from packaging import version
6
7

import colossalai
8
from colossalai.inference import CaiInferEngine, LlamaModelInferPolicy
9
10
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

# torch.version.cuda is None on CPU-only torch builds; treat that as "unsupported"
# instead of letting version.parse(None) raise during import / pytest collection.
CUDA_SUPPORT = torch.version.cuda is not None and version.parse(torch.version.cuda) > version.parse("11.5")


def data_gen():
    """Return one tokenised prompt as int64 ``input_ids`` / ``attention_mask`` tensors of shape (1, 8)."""
    token_ids = [15496, 11, 616, 3290, 318, 13779, 318, 13779]
    # Every position holds a real token, so the mask is all ones.
    mask = [1] * len(token_ids)
    return {
        "input_ids": torch.tensor([token_ids], dtype=torch.int64),
        "attention_mask": torch.tensor([mask], dtype=torch.int64),
    }


def _expand_batch(batch, batch_size=16, device=None):
    """Tile every tensor in ``batch`` ``batch_size`` times along dim 0 and move it to ``device``.

    Args:
        batch: mapping of name -> value; tensors (or objects whose class name
            contains "Tensor") are expanded, anything else is passed through.
        batch_size: number of copies along the leading dimension.
        device: target device; defaults to "cuda" when a GPU is visible and
            "cpu" otherwise, so importing this module on a GPU-less host (e.g.
            during pytest collection) does not crash.

    Returns:
        A new dict with the same keys, in the same order.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    expanded = {}
    for key, value in batch.items():
        if torch.is_tensor(value) or "Tensor" in type(value).__name__:
            # repeat(batch_size, 1, ..., 1): tile only the leading dimension.
            # Building the repeat spec from dim() also works for 0-d tensors.
            value = value.to(device).repeat(batch_size, *([1] * (value.dim() - 1)))
        expanded[key] = value
    return expanded


# Shared model inputs: the single sample from data_gen() tiled to a batch of 16.
inputs = _expand_batch(data_gen())


def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Run a small randomly-initialised Llama model through ``CaiInferEngine`` and check output length.

    Args:
        tp_size: tensor-parallel degree passed to the engine.
        pp_size: pipeline-parallel degree passed to the engine.
        max_output_len: passed as ``max_output_len``; the first returned
            output is expected to have exactly this length.
        micro_batch_size: micro-batch size passed to the engine.

    Feeds the module-level ``inputs`` batch. Only rank 0 checks the result.
    """
    # Deliberately small config (4 layers, hidden size 512) to keep the test lightweight.
    config = transformers.LlamaConfig(
        vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
    )
    model = transformers.LlamaForCausalLM(config)

    engine = CaiInferEngine(
        tp_size=tp_size,
        pp_size=pp_size,
        model=model,
        model_policy=LlamaModelInferPolicy(),
        max_output_len=max_output_len,
        micro_batch_size=micro_batch_size,
    )
    output = engine.inference(inputs)
    # NOTE(review): presumably only rank 0 receives the gathered output — confirm.
    if dist.get_rank() == 0:
        # Report the length that is actually being checked.
        assert len(output[0]) == max_output_len, f"{len(output[0])}, {max_output_len}"


@parameterize("tp_size", [1])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Pure pipeline-parallel case (tp=1, pp=2): run the inference check, then free cached GPU memory."""
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
    torch.cuda.empty_cache()


@parameterize("tp_size", [2])
@parameterize("pp_size", [2])
@parameterize("max_output_len", [4])
@parameterize("micro_batch_size", [1])
@clear_cache_before_run()
def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
    """Tensor + pipeline parallel case (tp=2, pp=2): run the inference check, then free cached GPU memory."""
    pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
    torch.cuda.empty_cache()


def check_pipeline_inference(rank, world_size, port):
    """Per-process worker: initialise distributed state via ``colossalai.launch`` (NCCL, localhost), then run the pipeline-only cases.

    Args:
        rank: this process's global rank.
        world_size: total number of spawned processes.
        port: rendezvous port on localhost.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_pipeline_inference_test()


def check_tp_pipeline_inference(rank, world_size, port):
    """Per-process worker for the tensor + pipeline parallel run.

    Initialises distributed state through ``colossalai.launch`` on localhost
    with the NCCL backend, then runs ``run_tp_pipeline_inference_test``.

    Args:
        rank: this process's global rank.
        world_size: total number of spawned processes.
        port: rendezvous port on localhost.
    """
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    colossalai.launch(**launch_kwargs)
    run_tp_pipeline_inference_test()


@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
    """Run the pipeline-only case on 2 processes, then the tensor + pipeline case on 4 processes."""
    spawn(check_pipeline_inference, nprocs=2)
    spawn(check_tp_pipeline_inference, nprocs=4)


# Allow running the distributed test directly without pytest.
if __name__ == "__main__":
    test_pipeline_inference()