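"""benchmark_tokenizer.py

Benchmark a Hugging Face tokenizer: per-prompt encode() calls (sequential) versus a
single batched tokenizer() call. Random prompts of roughly NUM_TOKENS tokens are
generated once, then each batch size in BATCH_SIZES is timed over NUM_RUNS runs and
the average times and speedup factor are reported.
"""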
import random
import time
from statistics import mean

from transformers import AutoTokenizer

# CONFIG
TOKENIZER_DIR = (
    "/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
)
NUM_TOKENS = 20000  # Target number of tokens per prompt (approximate after decode/re-encode)
BATCH_SIZES = [1, 2, 4, 8]  # Test different batch sizes
NUM_RUNS = 5  # Number of runs for each batch size to get reliable measurements


def generate_random_prompts(num_prompts, num_tokens, tokenizer):
    """Generate random prompts with specified token count."""
    vocab_size = tokenizer.vocab_size
    all_prompts = []

    print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
    for i in range(num_prompts):
        # Sample random token IDs; decoding and then re-encoding (with the prompt
        # prefix) does not round-trip exactly, so the actual token count is printed below
        random_token_ids = [
            random.randint(0, vocab_size - 1) for _ in range(num_tokens)
        ]
        random_text = tokenizer.decode(
            random_token_ids, clean_up_tokenization_spaces=True
        )

        prompt = f"Prompt {i}: {random_text}"
        tokens = tokenizer.encode(prompt)
        print(f"  Prompt {i}: {len(tokens)} tokens")
        all_prompts.append(prompt)

    return all_prompts


def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
    """Compare sequential vs batch tokenization for a given batch size."""

    # Sequential tokenization using encode()
    sequential_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.perf_counter()
        for prompt in batch_prompts:
            tokens = tokenizer.encode(prompt)
        sequential_time = (time.perf_counter() - start_time) * 1000
        sequential_times.append(sequential_time)

    # Batch tokenization using tokenizer()
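    # Note: calling the tokenizer on a list of strings batches the encode step; with a
    # fast (Rust-backed) tokenizer this work can be parallelized internally, which is
    # where any speedup over per-prompt encode() calls would come from.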
    batch_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.perf_counter()
        tokens = tokenizer(batch_prompts)
        batch_time = (time.perf_counter() - start_time) * 1000
        batch_times.append(batch_time)

    return {
        "batch_size": batch_size,
        "avg_sequential_ms": mean(sequential_times),
        "avg_batch_ms": mean(batch_times),
        "speedup_factor": (
            mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
        ),
        "sequential_runs": sequential_times,
        "batch_runs": batch_times,
    }


def main():
    print("Tokenizer Benchmark: Sequential vs Batch Processing")
    print("-" * 60)
    print(f"Tokenizer: {TOKENIZER_DIR}")
    print(f"Tokens per prompt: {NUM_TOKENS}")
    print(f"Number of runs per batch size: {NUM_RUNS}")
    print("-" * 60)

    # Load tokenizer once for all operations
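    # AutoTokenizer returns the fast (Rust-backed) tokenizer by default when one is
    # available; pass use_fast=False here to benchmark the pure-Python tokenizer instead.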
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

    # The largest batch size determines how many prompts we need
    max_batch_size = max(BATCH_SIZES)
    all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)

    results = []
    print("\nRunning benchmark...")

    for batch_size in BATCH_SIZES:
        print(f"\nBenchmarking batch size: {batch_size}")
        result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
        results.append(result)

        print(f"  Sequential tokenization (encode):")
        for i, run_time in enumerate(result["sequential_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_sequential_ms']:.2f} ms")

        print(f"  Batch tokenization (tokenizer):")
        for i, run_time in enumerate(result["batch_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_batch_ms']:.2f} ms")

        print(f"  Speedup factor: {result['speedup_factor']:.2f}x")

    print("\n" + "=" * 60)
    print("SUMMARY OF RESULTS")
    print("=" * 60)
    print(
        f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
    )
    print("-" * 60)

    for result in results:
        sequential_col = f"{result['avg_sequential_ms']:.2f} ms"
        batch_col = f"{result['avg_batch_ms']:.2f} ms"
        print(
            f"{result['batch_size']:<10} {sequential_col:<18} {batch_col:<18} "
            f"{result['speedup_factor']:.2f}x"
        )


if __name__ == "__main__":
    random.seed(0)  # Fixed seed so the generated prompts are reproducible across runs
    main()