Commit 09ab8fa4 authored by PanZezhong's avatar PanZezhong
Browse files

issue/168 fix bench.py

parent db19cc0b
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
 from infinilm.modeling_utils import load_model_state_dict_by_file
 from infinilm.distributed import DistConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
+from infinilm.cache import StaticKVCacheConfig
 import argparse
 import sys
 import time
@@ -260,6 +261,7 @@ class TestModel:
         output_ids = self.model.generate(
             input_ids_infini,
             GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
+            _measure_and_log_time=True,
         )
         t2 = time.time()
@@ -336,7 +338,11 @@ if __name__ == "__main__":
         # reset cache for each case
         initial_capacity = input_len + output_len
-        test.model.reset_cache(batch_size=batch_size, initial_capacity=initial_capacity)
+        test.model.reset_cache(
+            StaticKVCacheConfig(
+                max_batch_size=batch_size, max_cache_len=initial_capacity
+            )
+        )
         # run test one case
         test.run(
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment