test_evo2_generation_batched.py 10.2 KB
Newer Older
1
2
3
4
import argparse
import csv
from importlib import resources
from pathlib import Path
one's avatar
one committed
5
from typing import Optional
6
7
import numpy as np
import time
one's avatar
one committed
8

9
10
11
12
import torch

from evo2 import Evo2

one's avatar
one committed
13

14
15
def read_prompts(input_file):
    """Read prompt sequences from a CSV file or built-in test data.

    Args:
        input_file: Either a path to an existing CSV file, or the name of a
                    packaged test data file (e.g., 'prompts.csv') resolved
                    from ``evo2.test.data``. The first column of each row
                    after the header is taken as the prompt sequence.

    Returns:
        List of prompt sequence strings (empty if the file has no data rows).
    """
    # If the argument is not an existing file path, resolve it as package data.
    if isinstance(input_file, str) and not Path(input_file).is_file():
        with resources.path("evo2.test.data", input_file) as data_path:
            input_file = data_path

    promptseqs = []
    # utf-8-sig strips a BOM if present; newline="" is required by the csv module.
    with open(input_file, encoding="utf-8-sig", newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # Skip header; tolerate a completely empty file.
        for row in reader:
            promptseqs.append(row[0])
    return promptseqs

one's avatar
one committed
36

37
38
def mid_point_split(*, seq, num_tokens):
    """Split a sequence into a prompt and the target continuation.

    The split index is the midpoint of ``seq`` rounded down to an even value;
    the target is the next ``num_tokens`` characters after the split.
    """
    half = 2 * (len(seq) // 4)
    return seq[:half], seq[half : half + num_tokens]

one's avatar
one committed
44

45
46
47
48
def calculate_sequence_identity(seq1: str, seq2: str) -> Optional[float]:
    """Return percent identity between two sequences by positionwise comparison.

    Compares characters up to the length of the shorter sequence; returns
    ``None`` when either sequence is empty.
    """
    if not (seq1 and seq2):
        return None

    overlap = min(len(seq1), len(seq2))
    matched = 0
    for a, b in zip(seq1, seq2):  # zip stops at the shorter sequence
        if a == b:
            matched += 1
    return (matched / overlap) * 100

one's avatar
one committed
54
55
56
57
58
59
60
61
62
63
64
65

def generate_and_score(
    *,
    sequences,
    model,
    generations_per_prompt=5,
    n_tokens=500,
    temperature=1.0,
    top_k=1,
    top_p=1.0,
    batch_size=2,
):
    """Prompt with first half, generate and score on 2nd half.

    Args:
        sequences: Iterable of full sequences; each is split at its midpoint
            into a prompt and a target continuation.
        model: Evo2-style model exposing ``generate(prompt_seqs=..., ...)``
            returning an object with a ``sequences`` attribute.
        generations_per_prompt: Number of generations per input sequence.
        n_tokens: Number of tokens to generate (and score) per prompt.
        temperature: Sampling temperature passed to the model.
        top_k: Top-k sampling parameter passed to the model.
        top_p: Top-p sampling parameter passed to the model.
        batch_size: Number of prompts per ``generate()`` call.

    Returns:
        List of lists of identity scores (percent, or None for empty
        generations), one inner list per input sequence.
    """
    scores = []
    prompts = []
    targets = []

    # Expand each sequence into generations_per_prompt prompt/target pairs.
    for seq in sequences:
        prompt, target = mid_point_split(seq=seq, num_tokens=n_tokens)
        prompts.extend([prompt] * generations_per_prompt)
        targets.extend([target] * generations_per_prompt)

    # Only synchronize when CUDA is actually available so CPU-only runs
    # don't crash inside the timing code.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i : i + batch_size]
        batch_targets = targets[i : i + batch_size]

        with torch.inference_mode():
            # Synchronize before/after so the wall-clock time covers the
            # full device-side work of generate().
            sync()
            step_time = -time.perf_counter()
            generated = model.generate(
                prompt_seqs=batch_prompts,
                n_tokens=n_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            sync()
            step_time += time.perf_counter()
            print(
                f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
            )

            for j, decoded_seq in enumerate(generated.sequences):
                score = calculate_sequence_identity(decoded_seq, batch_targets[j])
                scores.append(score)

    # Reshape scores to group by original sequence
    reshaped_scores = [
        scores[i : i + generations_per_prompt]
        for i in range(0, len(scores), generations_per_prompt)
    ]

    return reshaped_scores


def generate_and_score_prof(
    *,
    sequences,
    model,
    generations_per_prompt=5,
    n_tokens=500,
    temperature=1.0,
    top_k=1,
    top_p=1.0,
    batch_size=2,
    trace_step=1,
    trace_logdir="./log/pt-trace/",
    trace_gzip=False,
    trace_file_prefix=None,
):
    """Prompt with first half, generate and score on 2nd half with torch profiler.

    The profiler schedule waits 0 steps, warms up for ``trace_step`` steps,
    then records exactly one step — i.e. the batch at 0-based step index
    ``trace_step`` is the one captured. Traces are written in TensorBoard
    format under ``trace_logdir``.

    Args:
        sequences: Iterable of full sequences; each is split at its midpoint
            into a prompt and a target continuation.
        model: Evo2-style model exposing ``generate(prompt_seqs=..., ...)``.
        generations_per_prompt: Number of generations per input sequence.
        n_tokens: Number of tokens to generate (and score) per prompt.
        temperature: Sampling temperature passed to the model.
        top_k: Top-k sampling parameter passed to the model.
        top_p: Top-p sampling parameter passed to the model.
        batch_size: Number of prompts per ``generate()`` call.
        trace_step: Number of warmup steps before the single profiled step.
        trace_logdir: Output directory for the trace files.
        trace_gzip: Whether to gzip the trace output.
        trace_file_prefix: Optional worker-name prefix for the trace file.

    Returns:
        List of lists of identity scores, one inner list per input sequence.
    """
    scores = []
    prompts = []
    targets = []

    # Prepare all prompts and targets
    for seq in sequences:
        prompt, target = mid_point_split(seq=seq, num_tokens=n_tokens)
        prompts.extend([prompt] * generations_per_prompt)
        targets.extend([target] * generations_per_prompt)

    print("\n[TRACE] Start profiling...")

    # Only synchronize when CUDA is actually available so CPU-only runs
    # don't crash inside the timing code.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

    # Optional profiler features are all disabled to keep traces small and
    # overhead low; enable them individually as needed.
    with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(
            dir_name=trace_logdir, worker_name=trace_file_prefix, use_gzip=trace_gzip
        ),
        record_shapes=False,
        profile_memory=False,
        with_stack=False,
        with_flops=False,
        with_modules=False,
    ) as prof:
        for i in range(0, len(prompts), batch_size):
            batch_prompts = prompts[i : i + batch_size]
            batch_targets = targets[i : i + batch_size]

            with torch.inference_mode():
                sync()
                step_time = -time.perf_counter()
                generated = model.generate(
                    prompt_seqs=batch_prompts,
                    n_tokens=n_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                )
                sync()
                step_time += time.perf_counter()
                print(
                    f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
                )

                for j, decoded_seq in enumerate(generated.sequences):
                    score = calculate_sequence_identity(decoded_seq, batch_targets[j])
                    scores.append(score)
            # Advance the profiler schedule once per batch.
            prof.step()

    # Reshape scores to group by original sequence
    reshaped_scores = [
        scores[i : i + generations_per_prompt]
        for i in range(0, len(scores), generations_per_prompt)
    ]

    return reshaped_scores

one's avatar
one committed
186

187
188
189
190
191
def main():
    """
    Test sequence generation and scoring using the evo2 models
    Expected results (direct comparison w/o alignment):
    - Evo 2 40B 1m: 91.15%
    - Evo 2 7B 1m: 89.25%
    - Evo 2 1B base: 68.0%
    - Evo 2 20B 1m: 93.4%
    """
    parser = argparse.ArgumentParser(description="Test Evo2 Model Generation")
    parser.add_argument(
        "--model_name",
        choices=["evo2_7b", "evo2_40b", "evo2_1b_base", "evo2_20b"],
        default="evo2_7b",
        help="Model to test (supports evo2_7b, evo2_40b, evo2_1b_base, evo2_20b)",
    )
    parser.add_argument("--local_path", type=str, default=None)
    parser.add_argument(
        "--n_tokens", type=int, default=500, help="Number of tokens to generate"
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size for generation"
    )
    parser.add_argument(
        "--prompt_stretch",
        action="store_true",
        help="Stretch all prompts to the longest prompt length",
    )
    parser.add_argument(
        "--n_warmups",
        type=int,
        default=0,
        help="Number of warmups to run",
    )
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Enable torch profiler",
    )
    parser.add_argument(
        "--trace_step",
        type=int,
        default=1,
        help="Attach torch profiler to specific step (default: 1)",
    )
    parser.add_argument(
        "--trace_logdir",
        type=str,
        default="./log/pt-trace/",
        help="Directory for torch profiler trace output (default: ./log/pt-trace/)",
    )
    parser.add_argument(
        "--trace_gzip",
        action="store_true",
        help="Gzip torch profiler trace output",
    )
    parser.add_argument(
        "--trace_file_prefix",
        type=str,
        default=None,
        help="Prefix for torch profiler trace output file",
    )

    args = parser.parse_args()

    # Reduce CUDA memory fragmentation for large models (e.g. evo2_20b).
    # NOTE(review): this is a private torch allocator API and may change
    # between torch versions — confirm against the pinned torch release.
    torch.cuda.memory._set_allocator_settings("expandable_segments:True")

    # Set random seeds for reproducible sampling
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    model = Evo2(args.model_name, local_path=args.local_path)

    # Test parameters: greedy sampling (top_k=1) of args.n_tokens tokens
    test_params = {
        "n_tokens": args.n_tokens,
        "temperature": 1.0,
        "top_k": 1,
        "top_p": 1.0,
        "generations_per_prompt": 1,
        "batch_size": args.batch_size,
    }

    # Read and process sequences
    sequences = read_prompts("prompts.csv")
    print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences])

    # Batched generation needs uniform prompt lengths, so replace every
    # prompt with the longest one (as documented by --prompt_stretch).
    if args.prompt_stretch or args.batch_size > 1:
        uniform_prompt = max(sequences, key=len)
        sequences = [uniform_prompt] * len(sequences)
        print(
            f"[DEBUG] Using the uniform prompt with length {len(uniform_prompt)} for all sequences"
        )

    # Warmup: short (16-token) generations with the first prompt only
    if args.n_warmups > 0:
        print(f"[DEBUG] Running {args.n_warmups} warmups with the first prompt")
        warmup_sequences = sequences[:1] * args.n_warmups
        warmup_params = {**test_params, "n_tokens": 16}
        generate_and_score(sequences=warmup_sequences, model=model, **warmup_params)

    if args.trace:
        print("[TRACE] Using generate_and_score_prof with torch profiler")
        scores = generate_and_score_prof(
            sequences=sequences,
            model=model,
            trace_step=args.trace_step,
            trace_gzip=args.trace_gzip,
            trace_logdir=args.trace_logdir,
            trace_file_prefix=args.trace_file_prefix,
            **test_params,
        )
    else:
        scores = generate_and_score(sequences=sequences, model=model, **test_params)

    # Calculate and validate results
    mean_score = np.mean(scores)
    print("\nTest Results:")
    print("% Matching Nucleotides:", mean_score)

    # Validate against expected scores
    eps = 3  # large epsilon for direct comparison, since there are numeric differences by versions
    expected_scores = {
        "evo2_40b": 91.15,
        "evo2_7b": 89.25,
        "evo2_1b_base": 68.0,
        "evo2_20b": 93.4,
    }

    expected_score = expected_scores[args.model_name]
    if abs(mean_score - expected_score) < eps:
        print(f"\nTest Passed! Score matches expected {expected_score}%")
    else:
        print(f"\nTest Failed: Expected {expected_score}%, got {mean_score}%")


if __name__ == "__main__":
    main()