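# Benchmarks an AWQ-quantized causal LM: measures prefill and decode
# throughput (tokens/s) plus peak VRAM use across several prompt lengths.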
import time
import torch
import argparse
import numpy as np
import pandas as pd
from awq import AutoAWQForCausalLM

def warmup(model):
    # One large matmul on the model's device to trigger CUDA context and
    # kernel initialization before any timed work.
    warm_up = torch.randn((4096, 4096)).to(next(model.parameters()).device)
    torch.mm(warm_up, warm_up)

def generate(model, input_ids, n_generate):
    context_time = 0
    generate_time = 0

    with torch.inference_mode():
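        # Iteration 0 runs the full prompt (prefill); later iterations feed
        # only the newly generated token, relying on the KV cache kept by the
        # fused modules loaded with fuse_layers=True.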
        for i in range(n_generate):
            torch.cuda.synchronize()
            start = time.time()

            if i == 0:
                # prefill context
                inputs = torch.as_tensor([input_ids], device=next(model.parameters()).device)
            else:
                # decode tokens
                inputs = torch.as_tensor([[token]], device=next(model.parameters()).device)
            
            out = model(inputs, use_cache=True)

            torch.cuda.synchronize()
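            # Greedy decoding: argmax over the logits at the last position.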
            token = out[0][:, -1].max(1)[1].unsqueeze(1)

            if i == 0:
                context_time += time.time() - start
            else:
                generate_time += time.time() - start
    
    return context_time, generate_time

def run_round(model_path, quant_file, n_generate, input_ids):
    print(f" -- Loading model...")
    model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True)

    print(f" -- Warming up...")
    warmup(model)
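
    # Start peak-memory tracking fresh for this round so the VRAM figure
    # below is not inflated by a previous round (added here; not in the
    # original script).
    torch.cuda.reset_peak_memory_stats()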

    print(f" -- Generating {n_generate} tokens, {len(input_ids)} token prompt...")
    context_time, generate_time = generate(model, input_ids, n_generate)

    device = next(model.parameters()).device
    # Prefill processes the whole prompt in one forward pass; decode covers
    # the remaining n_generate - 1 single-token steps.
    prefill_tokens_per_second = len(input_ids) / context_time
    decode_tokens_per_second = (n_generate - 1) / generate_time
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
    memory_pct = memory_used / total_memory * 100

    print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
    print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
    print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")

    return {
        "Prefill length": len(input_ids),
        "Decode length": n_generate,
        "Prefill tokens/s": prefill_tokens_per_second,
        "Decode tokens/s": decode_tokens_per_second,
        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)",
        "GPU": torch.cuda.get_device_name()
    }

def main(args):
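    # Each round benchmarks one (prompt length, generation length) pair.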
    rounds = [
        {"context": 4, "n_generate": 200},
        {"context": 32, "n_generate": 32},
        {"context": 64, "n_generate": 64},
        {"context": 128, "n_generate": 128},
        {"context": 256, "n_generate": 256},
        {"context": 512, "n_generate": 512},
        {"context": 1024, "n_generate": 1024},
        {"context": 2048, "n_generate": 2048},
    ]

    all_stats = []

    for settings in rounds:
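        # Dummy prompt: a sequence of token id 1 (BOS in Llama-style
        # tokenizers); throughput does not depend on token identity.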
        input_ids = [1 for _ in range(settings["context"])]

        stats = run_round(
            args.model_path,
            args.quant_file,
            settings["n_generate"],
            input_ids
        )
        
        all_stats.append(stats)
    
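    # Note: DataFrame.to_markdown requires the optional `tabulate` package.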
    df = pd.DataFrame(all_stats)
    print(df.to_markdown(index=False))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="vicuna-7b-v1.5-awq-gemv", help="path to the model")
    parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
    args = parser.parse_args()

    main(args)