Unverified Commit f67956fe authored by PanZezhong1725, committed by GitHub

Merge pull request #247 from InfiniTensor/issue/246

issue/246 change default kvcache blocksize to 256
parent 08090824
......@@ -86,7 +86,7 @@ class PagedKVCacheConfig final : public CacheConfig {
public:
PagedKVCacheConfig(
size_t num_blocks,
- size_t block_size = 16);
+ size_t block_size = 256);
std::unique_ptr<CacheConfig> unique_copy() const override;
size_t num_blocks() const;
......
......@@ -37,7 +37,7 @@ inline void bind_cache(py::module &m) {
.def(
py::init<size_t, size_t>(),
py::arg("num_blocks"),
py::arg("block_size") = 16)
py::arg("block_size") = 256)
.def(
"num_blocks",
&infinilm::cache::PagedKVCacheConfig::num_blocks)
......
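With both the C++ default and the pybind keyword default moved to 256, a config constructed without an explicit block_size now gets 256-token blocks. A minimal usage sketch (the import path below is an assumption, not taken from this diff):

# Minimal sketch; the import path is assumed, not shown in this diff.
from infinilm import PagedKVCacheConfig

# block_size is omitted, so it falls back to the new default of 256 (previously 16).
cfg = PagedKVCacheConfig(num_blocks=512)
print(cfg.num_blocks())  # -> 512; num_blocks() is the accessor bound in the hunk above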
......@@ -22,6 +22,8 @@ DATA_TYPE_BYTES = {
"float32": 4,
}
+ _PAGED_KV_BLOCK_SIZE = 256
# BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128]
# INPUT_LENS = [32, 256, 1024, 4096]
# OUTPUT_LENS = [256, 1024, 4096]
......@@ -234,6 +236,12 @@ def get_args():
action="store_true",
help="use paged cache",
)
+ parser.add_argument(
+ "--paged_kv_block_size",
+ type=int,
+ default=256,
+ help="number of tokens each KV cache block can hold",
+ )
parser.add_argument(
"--enable-graph",
action="store_true",
......@@ -399,6 +407,7 @@ if __name__ == "__main__":
"python examples/bench.py --nvidia --model=~/TinyLlama-1.1B-Chat-v1.0/ --batch-size=2 --tp=1 --input-len=50 --output-len=50"
)
sys.exit(1)
+ _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size
# -------------------------------------------------------- #
# Parse arguments
# -------------------------------------------------------- #
......@@ -430,10 +439,14 @@ if __name__ == "__main__":
# Test
# -------------------------------------------------------- #
if enable_paged_attn:
- paged_kv_block_size = 16
+ paged_kv_block_size = _PAGED_KV_BLOCK_SIZE
max_num_blocks = max(
[
((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
(
(c_["input_len"] + c_["output_len"] + (paged_kv_block_size - 1))
// paged_kv_block_size
)
* c_["batch_size"]
for _, c_ in cases_dict.items()
]
)
......
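The new sizing logic is a plain ceiling division: each sequence needs ceil((input_len + output_len) / block_size) blocks, and the pool is sized for the most demanding case times its batch size. A small worked example with made-up case values:

# Hypothetical values, only to illustrate the arithmetic used above.
input_len, output_len, batch_size, block_size = 1024, 256, 4, 256
blocks_per_seq = (input_len + output_len + block_size - 1) // block_size  # ceil(1280 / 256) = 5
max_num_blocks = blocks_per_seq * batch_size  # 5 * 4 = 20
assert max_num_blocks == 20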
......@@ -15,6 +15,8 @@ from packaging import version
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))
+ _PAGED_KV_BLOCK_SIZE = 256
def get_args():
parser = argparse.ArgumentParser(description="run Llama args")
......@@ -105,6 +107,14 @@ def get_args():
action="store_true",
help="use paged cache",
)
+ parser.add_argument(
+ "--paged_kv_block_size",
+ type=int,
+ default=256,
+ help="number of tokens each KV cache block can hold",
+ )
parser.add_argument(
"--enable-graph",
action="store_true",
......@@ -225,7 +235,11 @@ def test(
batch_size = 1 if isinstance(prompts, str) else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig(
- num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16
+ num_blocks=(
+ (max_total_tokens + (_PAGED_KV_BLOCK_SIZE - 1)) // _PAGED_KV_BLOCK_SIZE
+ )
+ * batch_size,
+ block_size=_PAGED_KV_BLOCK_SIZE,
)
else:
batch_size = 1 if isinstance(prompts, str) else len(prompts)
......@@ -295,6 +309,7 @@ if __name__ == "__main__":
)
sys.exit(1)
prompts = [args.prompt for _ in range(args.batch_size)]
+ _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size
model_path = args.model_path
max_new_tokens = args.max_new_tokens
......
......@@ -17,7 +17,7 @@ class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
def __init__(
self,
num_blocks: int,
- block_size: int = 16,
+ block_size: int = 256,
):
_infinilm.PagedKVCacheConfig.__init__(
self,
......
......@@ -34,15 +34,7 @@ class InferEngine(_infinilm.InferEngine):
if device is None:
device = infinicore.device()
- # super().__init__(
- # self.config,
- # distributed_config._underlying,
- # device._underlying.type,
- # cache_config,
- # enable_graph_compiling,
- # )
super().__init__(
model_path,
distributed_config._underlying,
......@@ -109,7 +101,6 @@ class InferEngine(_infinilm.InferEngine):
generation_config,
*,
_measure_and_log_time=False,
- paged_block_size=16,
):
if generation_config.eos_token_id is None:
eos_token_id = self.config.eos_token_id
......@@ -133,6 +124,7 @@ class InferEngine(_infinilm.InferEngine):
block_tables = None
max_blocks_per_batch = 0
if self.enable_paged_attn:
+ paged_block_size = self.get_cache_config().block_size()
max_blocks_per_batch = (
initial_seqlen + generation_config.max_new_tokens + paged_block_size - 1
) // paged_block_size
......
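generate() no longer accepts a paged_block_size keyword; the block size is read back from the engine's cache config, so callers cannot pass a value that disagrees with the cache. A rough sketch of the per-request block budget, with placeholder numbers:

# Placeholder numbers; in the real code paged_block_size comes from self.get_cache_config().block_size().
initial_seqlen = 50
max_new_tokens = 50
paged_block_size = 256
max_blocks_per_batch = (initial_seqlen + max_new_tokens + paged_block_size - 1) // paged_block_size
assert max_blocks_per_batch == 1  # 100 tokens fit in a single 256-token block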
......@@ -63,8 +63,8 @@ class EngineConfig:
cache_type: str = "paged" # "paged" or "static"
max_batch_size: int = 16
max_tokens: int = 4096
- num_blocks: int = 8 * 1024
- block_size: int = 16
+ num_blocks: int = 512
+ block_size: int = 256
max_cache_len: int = 4096
temperature: float = 1.0
top_p: float = 0.8
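Pairing the smaller num_blocks default with the larger block size keeps the default paged-cache capacity unchanged; only the allocation granularity changes:

# Default paged-KV capacity in tokens, before and after this change.
old_capacity = 8 * 1024 * 16  # 8192 blocks of 16 tokens
new_capacity = 512 * 256      # 512 blocks of 256 tokens
assert old_capacity == new_capacity == 131072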
......@@ -385,8 +385,8 @@ class LLM:
cache_type: str = "paged",
max_batch_size: int = 16,
max_tokens: int = 4096,
- num_blocks: int = 8 * 1024,
- block_size: int = 16,
+ num_blocks: int = 512,
+ block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
......@@ -538,8 +538,8 @@ class AsyncLLMEngine:
cache_type: str = "paged",
max_batch_size: int = 16,
max_tokens: int = 512,
- num_blocks: int = 8 * 1024,
- block_size: int = 16,
+ num_blocks: int = 512,
+ block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
......
......@@ -128,8 +128,8 @@ class Scheduler:
def __init__(
self,
max_batch_size: int = 16,
- num_blocks: int = 8 * 1024,
- block_size: int = 16,
+ num_blocks: int = 512,
+ block_size: int = 256,
):
self.waiting_queue = janus.Queue()
self.running_queue = janus.Queue()
......
......@@ -98,8 +98,8 @@ class InferenceServer:
cache_type: str = "paged",
max_tokens: int = 4096,
max_batch_size: int = 16,
- num_blocks: int = 8 * 1024,
- block_size: int = 16,
+ num_blocks: int = 512,
+ block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
......@@ -555,13 +555,13 @@ def parse_args():
parser.add_argument(
"--num_blocks",
type=int,
- default=8 * 1024,
+ default=512,
help="Number of blocks for KV cache (paged cache only)",
)
parser.add_argument(
"--block_size",
type=int,
- default=16,
+ default=256,
help="Block size for KV cache (paged cache only)",
)
parser.add_argument(
......