# bench_bloom.py: benchmark BLOOM text generation through the Colossal-AI TPInferEngine.
import argparse
import os
import time

import torch
from transformers import BloomForCausalLM, BloomTokenizerFast

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

# Set TRANSFORMERS_NO_ADVISORY_WARNINGS for this process before any model or
# tokenizer is loaded (presumably to keep the benchmark output free of
# transformers' advisory warnings; confirm against the transformers docs).
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


def print_perf_stats(latency_set, config, bs, warmup=3):
    """Print average latency, bandwidth, FLOPs and throughput for a benchmark run.

    Args:
        latency_set: Iterable of per-token latencies in seconds, one entry per
            benchmark iteration.
        config: Model config exposing ``hidden_size`` and either ``num_layers``
            or ``num_hidden_layers``.
        bs: Batch size used for the run.
        warmup: Number of leading measurements to discard as warm-up.

    Nothing is printed when no measurements remain after the warm-up entries
    have been dropped.
    """
    # trim warmup queries
    measured = sorted(list(latency_set)[warmup:])
    if not measured:
        return

    avg = sum(measured) / len(measured)

    # Look the layer count up lazily. A ``getattr`` default is evaluated before
    # the call, so passing ``config.num_hidden_layers`` as the default raised
    # AttributeError for configs that only define ``num_layers``.
    if hasattr(config, "num_layers"):
        num_layers = config.num_layers
    else:
        num_layers = config.num_hidden_layers

    # Rough parameter estimate: 12 * hidden_size^2 weights per layer.
    num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
    num_bytes = 2  # float16

    print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
    print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
    print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
    print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))


def bench_bloom(args):
    """Benchmark BLOOM generation through ``TPInferEngine`` and print perf stats.

    Loads the tokenizer and half-precision model from ``args.path``, wraps the
    model in a ``TPInferEngine``, runs a fixed number of timed ``generate``
    calls on random token ids, and reports the per-token latency statistics.

    Args:
        args: Parsed command-line namespace providing ``path``, ``batch_size``,
            ``input_len``, ``output_len`` and ``tp_size``.
    """
    model_path = args.path
    max_batch_size = args.batch_size
    max_input_len = args.input_len
    max_output_len = args.output_len

    tokenizer = BloomTokenizerFast.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
    model = model.half()

    # init TPInferEngine and shard the original model
    # To benchmark torch original, comment out the line of optimizing model
    shard_config = ShardConfig(enable_tensor_parallelism=args.tp_size > 1, inference_only=True)
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    # prepare data for generation
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
        "attention_mask": torch.ones((max_batch_size, max_input_len)),
    }
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
            print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")

    iters = 10
    times = []
    for i in range(iters):
        # Synchronize on both sides so queued GPU work is not attributed to the
        # wrong iteration; perf_counter is monotonic and high-resolution, so
        # system clock adjustments cannot distort the measurement.
        torch.cuda.synchronize()
        start = time.perf_counter()
        outputs = infer_engine.generate(input_tokens, **generate_kwargs)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

        out_len = outputs.shape[1]
        print(f" iter {i}: out len {out_len}, generation time {elapsed} s")

        # Assumes the output sequence includes the prompt, so the difference is
        # the number of newly generated tokens (TODO confirm for TPInferEngine).
        new_tokens = out_len - max_input_len
        if new_tokens <= 0:
            # Avoid a ZeroDivisionError / negative latency for a degenerate run.
            print(f" iter {i}: no new tokens generated, skipping this sample")
            continue
        times.append(elapsed / new_tokens)

    print_perf_stats(times, model.config, max_batch_size)


def check_bloom(rank, world_size, port, args):
    """Per-process entry point: set up the distributed context, then benchmark.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port on ``localhost`` used for the process-group rendezvous.
        args: Parsed command-line namespace handed on to ``bench_bloom``.
    """
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    bench_bloom(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom(args):
    """Launch the benchmark with one worker process per tensor-parallel rank.

    Args:
        args: Parsed command-line namespace; ``tp_size`` sets how many
            processes are spawned, and the namespace itself is passed to each
            worker as the ``args`` keyword argument.
    """
    num_workers = args.tp_size
    spawn(check_bloom, num_workers, args=args)


if __name__ == "__main__":
    # Command-line entry point: parse the benchmark settings and launch the run.
    parser = argparse.ArgumentParser(description="Benchmark BLOOM inference with TPInferEngine")
    parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
    parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
    parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
    parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")

    args = parser.parse_args()

    test_bloom(args)