Unverified commit f67956fe, authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #247 from InfiniTensor/issue/246

issue/246 change default kvcache blocksize to 256
parent 08090824
...@@ -86,7 +86,7 @@ class PagedKVCacheConfig final : public CacheConfig { ...@@ -86,7 +86,7 @@ class PagedKVCacheConfig final : public CacheConfig {
public: public:
PagedKVCacheConfig( PagedKVCacheConfig(
size_t num_blocks, size_t num_blocks,
size_t block_size = 16); size_t block_size = 256);
std::unique_ptr<CacheConfig> unique_copy() const override; std::unique_ptr<CacheConfig> unique_copy() const override;
size_t num_blocks() const; size_t num_blocks() const;
......
...@@ -37,7 +37,7 @@ inline void bind_cache(py::module &m) { ...@@ -37,7 +37,7 @@ inline void bind_cache(py::module &m) {
.def( .def(
py::init<size_t, size_t>(), py::init<size_t, size_t>(),
py::arg("num_blocks"), py::arg("num_blocks"),
py::arg("block_size") = 16) py::arg("block_size") = 256)
.def( .def(
"num_blocks", "num_blocks",
&infinilm::cache::PagedKVCacheConfig::num_blocks) &infinilm::cache::PagedKVCacheConfig::num_blocks)
......
...@@ -22,6 +22,8 @@ DATA_TYPE_BYTES = { ...@@ -22,6 +22,8 @@ DATA_TYPE_BYTES = {
"float32": 4, "float32": 4,
} }
_PAGED_KV_BLOCK_SIZE = 256
# BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128] # BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128]
# INPUT_LENS = [32, 256, 1024, 4096] # INPUT_LENS = [32, 256, 1024, 4096]
# OUTPUT_LENS = [256, 1024, 4096] # OUTPUT_LENS = [256, 1024, 4096]
...@@ -234,6 +236,12 @@ def get_args(): ...@@ -234,6 +236,12 @@ def get_args():
action="store_true", action="store_true",
help="use paged cache", help="use paged cache",
) )
parser.add_argument(
"--paged_kv_block_size",
type=int,
default=256,
help="num tokens each kv block can hold",
)
parser.add_argument( parser.add_argument(
"--enable-graph", "--enable-graph",
action="store_true", action="store_true",
...@@ -399,6 +407,7 @@ if __name__ == "__main__": ...@@ -399,6 +407,7 @@ if __name__ == "__main__":
"python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50" "python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50"
) )
sys.exit(1) sys.exit(1)
_PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size
# -------------------------------------------------------- # # -------------------------------------------------------- #
# 解析参数 # 解析参数
# -------------------------------------------------------- # # -------------------------------------------------------- #
...@@ -430,10 +439,14 @@ if __name__ == "__main__": ...@@ -430,10 +439,14 @@ if __name__ == "__main__":
# 测试 # 测试
# -------------------------------------------------------- # # -------------------------------------------------------- #
if enable_paged_attn: if enable_paged_attn:
paged_kv_block_size = 16 paged_kv_block_size = _PAGED_KV_BLOCK_SIZE
max_num_blocks = max( max_num_blocks = max(
[ [
((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"] (
(c_["input_len"] + c_["output_len"] + (paged_kv_block_size - 1))
// paged_kv_block_size
)
* c_["batch_size"]
for _, c_ in cases_dict.items() for _, c_ in cases_dict.items()
] ]
) )
......
...@@ -15,6 +15,8 @@ from packaging import version ...@@ -15,6 +15,8 @@ from packaging import version
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
_PAGED_KV_BLOCK_SIZE = 256
def get_args(): def get_args():
parser = argparse.ArgumentParser(description="run Llama args") parser = argparse.ArgumentParser(description="run Llama args")
...@@ -105,6 +107,14 @@ def get_args(): ...@@ -105,6 +107,14 @@ def get_args():
action="store_true", action="store_true",
help="use paged cache", help="use paged cache",
) )
parser.add_argument(
"--paged_kv_block_size",
type=int,
default=256,
help="num tokens each kv block can hold",
)
parser.add_argument( parser.add_argument(
"--enable-graph", "--enable-graph",
action="store_true", action="store_true",
...@@ -225,7 +235,11 @@ def test( ...@@ -225,7 +235,11 @@ def test(
batch_size = 1 if prompts is str else len(prompts) batch_size = 1 if prompts is str else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0]) max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig( cache_config = PagedKVCacheConfig(
num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16 num_blocks=(
(max_total_tokens + (_PAGED_KV_BLOCK_SIZE - 1)) // _PAGED_KV_BLOCK_SIZE
)
* batch_size,
block_size=_PAGED_KV_BLOCK_SIZE,
) )
else: else:
batch_size = 1 if prompts is str else len(prompts) batch_size = 1 if prompts is str else len(prompts)
...@@ -295,6 +309,7 @@ if __name__ == "__main__": ...@@ -295,6 +309,7 @@ if __name__ == "__main__":
) )
sys.exit(1) sys.exit(1)
prompts = [args.prompt for _ in range(args.batch_size)] prompts = [args.prompt for _ in range(args.batch_size)]
_PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size
model_path = args.model_path model_path = args.model_path
max_new_tokens = args.max_new_tokens max_new_tokens = args.max_new_tokens
......
...@@ -17,7 +17,7 @@ class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): ...@@ -17,7 +17,7 @@ class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
def __init__( def __init__(
self, self,
num_blocks: int, num_blocks: int,
block_size: int = 16, block_size: int = 256,
): ):
_infinilm.PagedKVCacheConfig.__init__( _infinilm.PagedKVCacheConfig.__init__(
self, self,
......
...@@ -34,15 +34,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -34,15 +34,7 @@ class InferEngine(_infinilm.InferEngine):
if device is None: if device is None:
device = infinicore.device() device = infinicore.device()
# super().__init__(
# self.config,
# distributed_config._underlying,
# device._underlying.type,
# cache_config,
# enable_graph_compiling,
# )
super().__init__( super().__init__(
model_path, model_path,
distributed_config._underlying, distributed_config._underlying,
...@@ -109,7 +101,6 @@ class InferEngine(_infinilm.InferEngine): ...@@ -109,7 +101,6 @@ class InferEngine(_infinilm.InferEngine):
generation_config, generation_config,
*, *,
_measure_and_log_time=False, _measure_and_log_time=False,
paged_block_size=16,
): ):
if generation_config.eos_token_id is None: if generation_config.eos_token_id is None:
eos_token_id = self.config.eos_token_id eos_token_id = self.config.eos_token_id
...@@ -133,6 +124,7 @@ class InferEngine(_infinilm.InferEngine): ...@@ -133,6 +124,7 @@ class InferEngine(_infinilm.InferEngine):
block_tables = None block_tables = None
max_blocks_per_batch = 0 max_blocks_per_batch = 0
if self.enable_paged_attn: if self.enable_paged_attn:
paged_block_size = self.get_cache_config().block_size()
max_blocks_per_batch = ( max_blocks_per_batch = (
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1 initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
) // paged_block_size ) // paged_block_size
......
...@@ -63,8 +63,8 @@ class EngineConfig: ...@@ -63,8 +63,8 @@ class EngineConfig:
cache_type: str = "paged" # "paged" or "static" cache_type: str = "paged" # "paged" or "static"
max_batch_size: int = 16 max_batch_size: int = 16
max_tokens: int = 4096 max_tokens: int = 4096
num_blocks: int = 8 * 1024 num_blocks: int = 512
block_size: int = 16 block_size: int = 256
max_cache_len: int = 4096 max_cache_len: int = 4096
temperature: float = 1.0 temperature: float = 1.0
top_p: float = 0.8 top_p: float = 0.8
...@@ -385,8 +385,8 @@ class LLM: ...@@ -385,8 +385,8 @@ class LLM:
cache_type: str = "paged", cache_type: str = "paged",
max_batch_size: int = 16, max_batch_size: int = 16,
max_tokens: int = 4096, max_tokens: int = 4096,
num_blocks: int = 8 * 1024, num_blocks: int = 512,
block_size: int = 16, block_size: int = 256,
max_cache_len: int = 4096, max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
...@@ -538,8 +538,8 @@ class AsyncLLMEngine: ...@@ -538,8 +538,8 @@ class AsyncLLMEngine:
cache_type: str = "paged", cache_type: str = "paged",
max_batch_size: int = 16, max_batch_size: int = 16,
max_tokens: int = 512, max_tokens: int = 512,
num_blocks: int = 8 * 1024, num_blocks: int = 512,
block_size: int = 16, block_size: int = 256,
max_cache_len: int = 4096, max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
......
...@@ -128,8 +128,8 @@ class Scheduler: ...@@ -128,8 +128,8 @@ class Scheduler:
def __init__( def __init__(
self, self,
max_batch_size: int = 16, max_batch_size: int = 16,
num_blocks: int = 8 * 1024, num_blocks: int = 512,
block_size: int = 16, block_size: int = 256,
): ):
self.waiting_queue = janus.Queue() self.waiting_queue = janus.Queue()
self.running_queue = janus.Queue() self.running_queue = janus.Queue()
......
...@@ -98,8 +98,8 @@ class InferenceServer: ...@@ -98,8 +98,8 @@ class InferenceServer:
cache_type: str = "paged", cache_type: str = "paged",
max_tokens: int = 4096, max_tokens: int = 4096,
max_batch_size: int = 16, max_batch_size: int = 16,
num_blocks: int = 8 * 1024, num_blocks: int = 512,
block_size: int = 16, block_size: int = 256,
max_cache_len: int = 4096, max_cache_len: int = 4096,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 0.8, top_p: float = 0.8,
...@@ -555,13 +555,13 @@ def parse_args(): ...@@ -555,13 +555,13 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--num_blocks", "--num_blocks",
type=int, type=int,
default=8 * 1024, default=512,
help="Number of blocks for KV cache (paged cache only)", help="Number of blocks for KV cache (paged cache only)",
) )
parser.add_argument( parser.add_argument(
"--block_size", "--block_size",
type=int, type=int,
default=16, default=256,
help="Block size for KV cache (paged cache only)", help="Block size for KV cache (paged cache only)",
) )
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment