"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "74c4d1b2a4038eab35483dda0078bd4902f18568"
Commit 09ab8fa4 authored by PanZezhong
Browse files

issue/168 fix bench.py

parent db19cc0b
......@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
+from infinilm.cache import StaticKVCacheConfig
import argparse
import sys
import time
......@@ -260,6 +261,7 @@ class TestModel:
output_ids = self.model.generate(
input_ids_infini,
GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
_measure_and_log_time=True,
)
t2 = time.time()
......@@ -336,7 +338,11 @@ if __name__ == "__main__":
# reset cache for each case
initial_capacity = input_len + output_len
-test.model.reset_cache(batch_size=batch_size, initial_capacity=initial_capacity)
+test.model.reset_cache(
+StaticKVCacheConfig(
+max_batch_size=batch_size, max_cache_len=initial_capacity
+)
+)
# run test one case
test.run(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment