diff --git a/evo2/configs/evo2-7b-1m.yml b/evo2/configs/evo2-7b-1m.yml
index 5ee5461..4fc408f 100644
--- a/evo2/configs/evo2-7b-1m.yml
+++ b/evo2/configs/evo2-7b-1m.yml
@@ -49,7 +49,7 @@ mha_out_proj_bias: True
 hyena_out_proj_bias: True
 hyena_flip_x1x2: False
 qkv_proj_bias: False
-use_fp8_input_projections: True
+use_fp8_input_projections: False
 max_seqlen: 1048576
 max_batch_size: 1
 final_norm: True 
diff --git a/evo2/test/test_evo2_generation.py b/evo2/test/test_evo2_generation.py
index bc420b7..928854f 100644
--- a/evo2/test/test_evo2_generation.py
+++ b/evo2/test/test_evo2_generation.py
@@ -4,6 +4,7 @@ from importlib import resources
 from pathlib import Path
 from typing import List, Optional, Union
 import numpy as np
+import time
 
 import torch
 
@@ -65,6 +66,8 @@ def generate_and_score(*, sequences, model, generations_per_prompt=5, n_tokens=5
         target = targets[i]
 
         with torch.inference_mode():
+            if torch.cuda.is_available(): torch.cuda.synchronize()
+            elapsed_time = -time.perf_counter()
             generated = model.generate(
                 prompt_seqs=[prompt],
                 n_tokens=n_tokens,
@@ -72,6 +75,9 @@ def generate_and_score(*, sequences, model, generations_per_prompt=5, n_tokens=5
                 top_k=top_k,
                 top_p=top_p,
             )
+            if torch.cuda.is_available(): torch.cuda.synchronize()
+            elapsed_time += time.perf_counter()
+            print(f"[{i}] Time for model.generate: {elapsed_time:.3f} s")
             
             decoded_seq = generated.sequences[0]  # Assuming generate returns list of sequences
             score = calculate_sequence_identity(decoded_seq, target)
@@ -95,7 +101,7 @@ def main():
     parser = argparse.ArgumentParser(description="Test Evo2 Model Generation")
     parser.add_argument("--model_name", choices=['evo2_7b', 'evo2_40b', 'evo2_1b_base', 'evo2_20b'], default='evo2_7b',
                        help="Model to test (supports evo2_7b, evo2_40b, evo2_1b_base, evo2_20b)")
-    
+    parser.add_argument("--local_path", type=str, default=None)
     args = parser.parse_args()
     
     # Reduce CUDA memory fragmentation for large models (e.g. evo2_20b)
@@ -105,7 +111,7 @@ def main():
     torch.manual_seed(1)
     torch.cuda.manual_seed(1)
         
-    model = Evo2(args.model_name)
+    model = Evo2(args.model_name, local_path=args.local_path)
     
     # Test parameters: greedy sampling of 500 tokens
     test_params = {
@@ -145,4 +151,4 @@ def main():
         print(f"\nTest Failed: Expected {expected_score}%, got {mean_score}%")
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/pyproject.toml b/pyproject.toml
index 4347bbc..77b3f6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@ name = "evo2"
 version = "0.5.3"
 description = "Genome modeling across all domains of life"
 readme = "README.md"
-requires-python = ">=3.11,<3.13"
+requires-python = ">=3.10,<3.13"
 license = {file = "LICENSE"}
 authors = [
     {name = "Evo 2 Team"},