# benchmark_tokenizer.py
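"""Benchmark HuggingFace tokenization: per-prompt encode() calls vs. a single
batched tokenizer() call, across several batch sizes, reporting per-run latency,
averages, and the resulting speedup factor."""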
import random
import time
from statistics import mean

from transformers import AutoTokenizer

# CONFIG
TOKENIZER_DIR = (
    "/shared/public/sharing/fait360brew/training/models/meta-llama/Llama-3.2-3B"
)
NUM_TOKENS = 20000  # Approximate number of tokens in each generated prompt
BATCH_SIZES = [1, 2, 4, 8]  # Test different batch sizes
NUM_RUNS = 5  # Number of runs for each batch size to get reliable measurements


def generate_random_prompts(num_prompts, num_tokens, tokenizer):
    """Generate random prompts with specified token count."""
    vocab_size = tokenizer.vocab_size
    all_prompts = []

    print(f"Generating {num_prompts} random prompts with {num_tokens} tokens each...")
    for i in range(num_prompts):
        # Sample random token IDs to hit roughly num_tokens per prompt; decoding
        # and re-encoding (plus the prefix below) can shift the count slightly.
        random_token_ids = [
            random.randint(0, vocab_size - 1) for _ in range(num_tokens)
        ]
        random_text = tokenizer.decode(
            random_token_ids, clean_up_tokenization_spaces=True
        )

        prompt = f"Prompt {i}: {random_text}"
        tokens = tokenizer.encode(prompt)
        print(f"  Prompt {i}: {len(tokens)} tokens")
        all_prompts.append(prompt)

    return all_prompts


def benchmark_sequential_vs_batch(prompts, batch_size, tokenizer):
    """Compare sequential vs batch tokenization for a given batch size."""

    # Sequential tokenization using encode()
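    # One Python-level encode() call per prompt, so per-call overhead accumulates
    # as the batch size grows.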
    sequential_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.perf_counter()
        for prompt in batch_prompts:
            tokenizer.encode(prompt)
        sequential_time = (time.perf_counter() - start_time) * 1000
        sequential_times.append(sequential_time)

    # Batch tokenization using tokenizer()
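    # Passing the whole list in one call lets a fast (Rust-backed) tokenizer,
    # when available, encode the batch internally instead of looping in Python.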
    batch_times = []
    for run in range(NUM_RUNS):
        batch_prompts = prompts[:batch_size]  # Use same prompts for fair comparison

        start_time = time.perf_counter()
        tokenizer(batch_prompts)
        batch_time = (time.perf_counter() - start_time) * 1000
        batch_times.append(batch_time)

    return {
        "batch_size": batch_size,
        "avg_sequential_ms": mean(sequential_times),
        "avg_batch_ms": mean(batch_times),
        "speedup_factor": (
            mean(sequential_times) / mean(batch_times) if mean(batch_times) > 0 else 0
        ),
        "sequential_runs": sequential_times,
        "batch_runs": batch_times,
    }


def main():
    print("Tokenizer Benchmark: Sequential vs Batch Processing")
    print("-" * 60)
    print(f"Tokenizer: {TOKENIZER_DIR}")
    print(f"Tokens per prompt: {NUM_TOKENS}")
    print(f"Number of runs per batch size: {NUM_RUNS}")
    print("-" * 60)

    # Load tokenizer once for all operations
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)

    # The largest batch size determines how many prompts we need
    max_batch_size = max(BATCH_SIZES)
    all_prompts = generate_random_prompts(max_batch_size, NUM_TOKENS, tokenizer)

    results = []
    print("\nRunning benchmark...")

    for batch_size in BATCH_SIZES:
        print(f"\nBenchmarking batch size: {batch_size}")
        result = benchmark_sequential_vs_batch(all_prompts, batch_size, tokenizer)
        results.append(result)

        print(f"  Sequential tokenization (encode):")
        for i, run_time in enumerate(result["sequential_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_sequential_ms']:.2f} ms")

        print(f"  Batch tokenization (tokenizer):")
        for i, run_time in enumerate(result["batch_runs"]):
            print(f"    Run {i+1}: {run_time:.2f} ms")
        print(f"    Average: {result['avg_batch_ms']:.2f} ms")

        print(f"  Speedup factor: {result['speedup_factor']:.2f}x")

    print("\n" + "=" * 60)
    print("SUMMARY OF RESULTS")
    print("=" * 60)
    print(
        f"{'Batch Size':<10} {'Sequential (ms)':<18} {'Batch (ms)':<18} {'Speedup':<10}"
    )
    print("-" * 60)

    for result in results:
        print(
            f"{result['batch_size']:<10} "
            f"{result['avg_sequential_ms']:<18.2f} "
            f"{result['avg_batch_ms']:<18.2f} "
            f"{result['speedup_factor']:.2f}x"
        )


if __name__ == "__main__":
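    # Fix the seed so the randomly generated prompts are reproducible across runs.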
    random.seed(0)
    main()