import argparse
import csv
import time
from importlib import resources
from pathlib import Path
from typing import Optional

import numpy as np
import torch

from evo2 import Evo2

def read_prompts(input_file):
    """Read prompts from input file or built-in test data.
    
    Args:
        input_file: Either a path to a file, or the name of a test data file
                   (e.g., 'prompts.csv')
    """
    # If it's a string that doesn't exist as a file path, assume it's a bundled test data file
    if isinstance(input_file, str) and not Path(input_file).is_file():
        # Resolve data shipped inside the evo2 package (importlib.resources.path is
        # deprecated on Python 3.11+, but still works for this lookup)
        with resources.path('evo2.test.data', input_file) as data_path:
            input_file = data_path

    # Read one prompt sequence per row, skipping the header
    promptseqs = []
    with open(input_file, encoding='utf-8-sig', newline='') as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for row in reader:
            promptseqs.append(row[0])
    return promptseqs

def mid_point_split(*, seq, num_tokens):
    """Split sequence at midpoint for prompt and target."""
    mid_point = 2*(len(seq)//4)
    prompt = seq[:mid_point]
    target = seq[mid_point:mid_point+num_tokens]
    return prompt, target

def calculate_sequence_identity(seq1: str, seq2: str) -> Optional[float]:
    """Calculate sequence identity between two sequences through direct comparison."""
    if not seq1 or not seq2:
        return None
    
    min_length = min(len(seq1), len(seq2))
    matches = sum(a == b for a, b in zip(seq1[:min_length], seq2[:min_length]))
    return (matches / min_length) * 100

def generate_and_score(*, sequences, model, generations_per_prompt=5, n_tokens=500,
                      temperature=1.0, top_k=1, top_p=1.0, batch_size=2):
    """Prompt with first half, generate and score on 2nd half."""
    scores = []
    prompts = []
    targets = []
    
    # Prepare all prompts and targets
    for seq in sequences:
        prompt, target = mid_point_split(seq=seq, num_tokens=n_tokens)
        prompts.extend([prompt] * generations_per_prompt)
        targets.extend([target] * generations_per_prompt)
    
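    # Generate in fixed-size batches; the final batch may be smaller than batch_size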
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        batch_targets = targets[i:i + batch_size]

        with torch.inference_mode():
            # Synchronize before and after timing so pending GPU work is included
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            generated = model.generate(
                prompt_seqs=batch_prompts,
                n_tokens=n_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - start_time
            print(f"[{i}:{min(i + batch_size, len(prompts))}) Time for model.generate: {elapsed_time:.3f} s")

            for j, decoded_seq in enumerate(generated.sequences):
                score = calculate_sequence_identity(decoded_seq, batch_targets[j])
                scores.append(score)
    
    # Reshape scores to group by original sequence
    reshaped_scores = [scores[i:i + generations_per_prompt] 
                      for i in range(0, len(scores), generations_per_prompt)]
    
    return reshaped_scores

def main():
    """
    Test sequence generation and scoring using the evo2 models
    Expected results (direct comparison w/o alignment):
    - Evo 2 40B 1m: 91.15%
    - Evo 2 7B 1m: 89.25% 
    - Evo 2 1B base: 68.0%
    """
    parser = argparse.ArgumentParser(description="Test Evo2 Model Generation")
    parser.add_argument("--model_name", choices=['evo2_7b', 'evo2_40b', 'evo2_1b_base'], default='evo2_7b',
                       help="Model to test (supports evo2_7b, evo2_40b, evo2_1b_base)")
    parser.add_argument("--local_path", type=str, default=None)
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for generation")
    
    args = parser.parse_args()
    
    # Set random seeds (torch.cuda.manual_seed is silently ignored when CUDA is unavailable)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    model = Evo2(args.model_name, local_path=args.local_path)

    # Test parameters: greedy sampling of 500 tokens
    test_params = {
        'n_tokens': 500,
        'temperature': 1.0,
        'top_k': 1,
        'top_p': 1.0,
        'generations_per_prompt': 1,
        'batch_size': args.batch_size,
    }
    
    # Read and process sequences
    sequences = read_prompts('prompts.csv')
    # DEBUG: replace all prompts with the longest prompt to enable uniform lengths
    longest_prompt = max(sequences, key=len)
    sequences = [longest_prompt] * len(sequences)
    print(f"[debug] Using longest prompt len={len(longest_prompt)} for all sequences")
    scores = generate_and_score(
        sequences=sequences,
        model=model,
        **test_params
    )
    
    # Calculate and validate results
    mean_score = np.mean(scores)
    print("\nTest Results:")
    print("% Matching Nucleotides:", mean_score)
    
    # Validate against expected scores
    eps = 3  # generous tolerance for direct comparison, since results vary slightly across versions
    expected_scores = {
        'evo2_40b': 91.15,
        'evo2_7b': 89.25,
        'evo2_1b_base': 68.0
    }
    
    expected_score = expected_scores[args.model_name]
    if abs(mean_score - expected_score) < eps:
        print(f"\nTest Passed! Score matches expected {expected_score}%")
    else:
        print(f"\nTest Failed: Expected {expected_score}%, got {mean_score}%")

if __name__ == "__main__":
    main()