# test_evo2_generation_batched.py
import argparse
import csv
from importlib import resources
from pathlib import Path
from typing import Optional

import numpy as np
import time
import torch

from evo2 import Evo2
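
# Usage sketch (assumption: these invocations simply mirror the argparse options defined
# in main() below; the evo2 model weights must already be available to the Evo2 class):
#
#   python test_evo2_generation_batched.py --model_name evo2_7b --batch_size 2
#   python test_evo2_generation_batched.py --model_name evo2_7b --trace --trace_logdir ./log/pt-trace/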


def read_prompts(input_file):
    """Read prompts from input file or built-in test data.

    Args:
        input_file: Either a path to a file, or the name of a test data file
                   (e.g., 'prompts.csv')
    """
    # If it's a string that doesn't exist as a file path, assume it's a test data file
    if isinstance(input_file, str) and not Path(input_file).is_file():
        # Resolve the name against the bundled package test data (evo2.test.data)
        with resources.path("evo2.test.data", input_file) as data_path:
            input_file = data_path

    # Read one prompt sequence per row (first column), skipping the header row
    promptseqs = []
    with open(input_file, encoding="utf-8-sig", newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for row in reader:
            promptseqs.append(row[0])
    return promptseqs
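
# For reference (hypothetical paths, not part of the test): read_prompts("prompts.csv")
# resolves the bundled test data via importlib.resources, while read_prompts("/tmp/my_prompts.csv")
# would read a local CSV directly; either way the file needs a header row followed by one
# sequence per row in the first column.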


def mid_point_split(*, seq, num_tokens):
    """Split sequence at midpoint for prompt and target."""
    mid_point = 2 * (len(seq) // 4)
    prompt = seq[:mid_point]
    target = seq[mid_point : mid_point + num_tokens]
    return prompt, target
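
# Worked example (hypothetical length, for illustration only): for a 10-character
# sequence, mid_point = 2 * (10 // 4) = 4, so the prompt is the first 4 characters and
# the target is the next num_tokens characters starting at index 4.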


def calculate_sequence_identity(seq1: str, seq2: str) -> Optional[float]:
    """Calculate sequence identity between two sequences through direct comparison."""
    if not seq1 or not seq2:
        return None

    min_length = min(len(seq1), len(seq2))
    matches = sum(a == b for a, b in zip(seq1[:min_length], seq2[:min_length]))
    return (matches / min_length) * 100
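
# Illustrative check (hypothetical inputs): "ACGT" vs. "ACTT" matches 3 of 4 aligned
# positions, so calculate_sequence_identity("ACGT", "ACTT") returns 75.0.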


def generate_and_score(
    *,
    sequences,
    model,
    generations_per_prompt=5,
    n_tokens=500,
    temperature=1.0,
    top_k=1,
    top_p=1.0,
    batch_size=2,
):
    """Prompt with first half, generate and score on 2nd half."""
    scores = []
    prompts = []
    targets = []

    # Prepare all prompts and targets
    for seq in sequences:
        prompt, target = mid_point_split(seq=seq, num_tokens=n_tokens)
        prompts.extend([prompt] * generations_per_prompt)
        targets.extend([target] * generations_per_prompt)

    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i : i + batch_size]
        batch_targets = targets[i : i + batch_size]

        with torch.inference_mode():
            torch.cuda.synchronize()
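            # Timing idiom: starting from the negative timestamp means that adding
            # perf_counter() again after generation yields the elapsed wall-clock time.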
            step_time = -time.perf_counter()
            generated = model.generate(
                prompt_seqs=batch_prompts,
                n_tokens=n_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            torch.cuda.synchronize()
            step_time += time.perf_counter()
            print(
                f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
            )

            for j, decoded_seq in enumerate(generated.sequences):
                score = calculate_sequence_identity(decoded_seq, batch_targets[j])
                scores.append(score)

    # Reshape scores to group by original sequence
    reshaped_scores = [
        scores[i : i + generations_per_prompt]
        for i in range(0, len(scores), generations_per_prompt)
    ]

    return reshaped_scores
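
# Shape note (illustrative values): with 3 input sequences and generations_per_prompt=2,
# generate_and_score returns [[s0a, s0b], [s1a, s1b], [s2a, s2b]], i.e. one list of
# identity scores per original input sequence.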


def generate_and_score_prof(
    *,
    sequences,
    model,
    generations_per_prompt=5,
    n_tokens=500,
    temperature=1.0,
    top_k=1,
    top_p=1.0,
    batch_size=2,
    trace_step=1,
    trace_logdir="./log/pt-trace/",
    trace_gzip=False,
    trace_file_prefix=None,
):
    """Prompt with first half, generate and score on 2nd half with torch profiler.

    The profiler records a single step (after `trace_step` warmup iterations) to capture detailed performance data.
    """
    scores = []
    prompts = []
    targets = []

    # Prepare all prompts and targets
    for seq in sequences:
        prompt, target = mid_point_split(seq=seq, num_tokens=n_tokens)
        prompts.extend([prompt] * generations_per_prompt)
        targets.extend([target] * generations_per_prompt)

    print("\n[TRACE] Start profiling...")

    # Enable optional profiler features below as needed
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
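        # Per torch.profiler.schedule semantics: wait=0 skips no steps, warmup=trace_step
        # runs the profiler without recording for the first trace_step batches, active=1
        # records exactly one batch, and repeat=1 performs this cycle once.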
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            dir_name=trace_logdir, worker_name=trace_file_prefix, use_gzip=trace_gzip
        ),
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
        with_flops=False,
        with_modules=False,
    ) as prof:
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i : i + batch_size]
            batch_targets = targets[i : i + batch_size]

            with torch.inference_mode():
                torch.cuda.synchronize()
                step_time = -time.perf_counter()
                generated = model.generate(
                    prompt_seqs=batch_prompts,
                    n_tokens=n_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                )
                torch.cuda.synchronize()
                step_time += time.perf_counter()
                print(
                    f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
                )

                for j, decoded_seq in enumerate(generated.sequences):
                    score = calculate_sequence_identity(decoded_seq, batch_targets[j])
                    scores.append(score)
            prof.step()

    # Reshape scores to group by original sequence
    reshaped_scores = [
        scores[i : i + generations_per_prompt]
        for i in range(0, len(scores), generations_per_prompt)
    ]

    return reshaped_scores
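
# Note: traces written by tensorboard_trace_handler can typically be inspected with the
# TensorBoard profiler plugin, e.g. `tensorboard --logdir ./log/pt-trace/` (assumes the
# torch-tb-profiler plugin is installed).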


def main():
    """
    Test sequence generation and scoring using the evo2 models
    Expected results (direct comparison w/o alignment):
    - Evo 2 40B 1m: 91.15%
    - Evo 2 7B 1m: 89.25%
    - Evo 2 1B base: 68.0%
    """
    parser = argparse.ArgumentParser(description="Test Evo2 Model Generation")
    parser.add_argument(
        "--model_name",
        choices=["evo2_7b", "evo2_40b", "evo2_1b_base"],
        default="evo2_7b",
        help="Model to test (supports evo2_7b, evo2_40b, evo2_1b_base)",
    )
    parser.add_argument("--local_path", type=str, default=None)
    parser.add_argument(
        "--n_tokens", type=int, default=500, help="Number of tokens to generate"
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for generation"
    )
    parser.add_argument(
        "--prompt_stretch",
        action="store_true",
        help="Stretch all prompts to the longest prompt length",
    )
    parser.add_argument(
        "--n_warmups",
        type=int,
        default=0,
        help="Number of warmups to run",
    )
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Enable torch profiler",
    )
    parser.add_argument(
        "--trace_step",
        type=int,
        default=1,
        help="Attach torch profiler to specific step (default: 1)",
    )
    parser.add_argument(
        "--trace_logdir",
        type=str,
        default="./log/pt-trace/",
        help="Directory for torch profiler trace output (default: ./log/pt-trace/)",
    )
    parser.add_argument(
        "--trace_gzip",
        action="store_true",
        help="Gzip torch profiler trace output",
    )
    parser.add_argument(
        "--trace_file_prefix",
        type=str,
        default=None,
        help="Prefix for torch profiler trace output file",
    )

    args = parser.parse_args()

    # Set random seeds
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    model = Evo2(args.model_name, local_path=args.local_path)

    # Test parameters: greedy sampling of 500 tokens
    test_params = {
        "n_tokens": args.n_tokens,
        "temperature": 1.0,
        "top_k": 1,
        "top_p": 1.0,
        "generations_per_prompt": 1,
        "batch_size": args.batch_size,
    }

    # Read and process sequences
    sequences = read_prompts("prompts.csv")
    print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences])

    # Debugging aid: replace all prompts with one uniform (the longest) prompt; also applied when batch_size > 1
    if args.prompt_stretch or args.batch_size > 1:
        uniform_prompt = sequences[1]  # length=7056
        sequences = [uniform_prompt] * len(sequences)
        print(
            f"[DEBUG] Using the uniform prompt with length {len(uniform_prompt)} for all sequences"
        )

    # Warmup
    if args.n_warmups > 0:
        warmup_sequences = sequences[:1] * args.n_warmups
        warmup_params = {**test_params, "n_tokens": 16}
        print(f"[DEBUG] Running {args.n_warmups} warmups with the first prompt")
        generate_and_score(sequences=warmup_sequences, model=model, **warmup_params)

    if args.trace:
        print("[TRACE] Using generate_and_score_prof with torch profiler")
        scores = generate_and_score_prof(
            sequences=sequences,
            model=model,
            trace_step=args.trace_step,
            trace_gzip=args.trace_gzip,
            trace_logdir=args.trace_logdir,
            trace_file_prefix=args.trace_file_prefix,
            **test_params,
        )
    else:
        scores = generate_and_score(sequences=sequences, model=model, **test_params)

    # Calculate and validate results
    mean_score = np.mean(scores)
    print("\nTest Results:")
    print("% Matching Nucleotides:", mean_score)

    # Validate against expected scores
    eps = 3  # large epsilon for direct comparison, since numeric results can vary across library versions
    expected_scores = {"evo2_40b": 91.15, "evo2_7b": 89.25, "evo2_1b_base": 68.0}

    expected_score = expected_scores[args.model_name]
    if abs(mean_score - expected_score) < eps:
        print(f"\nTest Passed! Score matches expected {expected_score}%")
    else:
        print(f"\nTest Failed: Expected {expected_score}%, got {mean_score}%")


if __name__ == "__main__":
    main()