Unverified commit 476584cb authored by Ying Sheng, committed by GitHub

Increase the capacity of the memory pool (#643)

parent abd5385a
@@ -3,9 +3,10 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

-from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
         self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
......
@@ -325,6 +325,11 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
......
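The new guard covers the case where `self.req_to_token_pool.alloc(bs)` cannot find `bs` free request slots and returns None; without it, the following `.cpu()` call fails with a less helpful AttributeError on NoneType. Below is a minimal sketch of an allocator with that contract, illustrative only and not the sglang implementation (the class and variable names are made up):

    import torch

    # Minimal sketch of a slot allocator with the same contract as
    # ReqToTokenPool.alloc(): return the selected indices, or None when fewer
    # than `need_size` slots are free.
    class ToySlotPool:
        def __init__(self, size: int):
            self.mem_state = torch.ones(size, dtype=torch.bool)  # True = free slot

        def alloc(self, need_size: int):
            free_slots = torch.nonzero(self.mem_state).squeeze(1)
            if free_slots.shape[0] < need_size:
                return None  # caller must turn this into a real error
            select_index = free_slots[:need_size]
            self.mem_state[select_index] = False
            return select_index

    pool = ToySlotPool(size=4)
    assert pool.alloc(3) is not None  # 3 of 4 slots taken
    assert pool.alloc(2) is None      # only 1 slot left -> None, hence the new check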
@@ -9,6 +9,12 @@ from typing import Optional, Type

 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -162,7 +168,7 @@
         )
         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            max(int(self.max_total_num_tokens / self.model_config.context_len * 512), 2048),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
@@ -193,13 +199,6 @@
             self.flashinfer_decode_wrapper = None
             return

-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
......
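The req-to-token pool now holds max(max_total_num_tokens / context_len * 512, 2048) request slots instead of the previous * 256 sizing, and each slot remains a row of context_len + 8 token indices. A worked example of the sizing change, with made-up numbers rather than values measured on any particular GPU or model:

    # Worked example of the new ReqToTokenPool sizing; the token budget and
    # context length below are hypothetical.
    max_total_num_tokens = 200_000  # tokens the KV cache can hold (hypothetical)
    context_len = 8_192             # model context length (hypothetical)

    old_size = int(max_total_num_tokens / context_len * 256)             # -> 6250 slots
    new_size = max(int(max_total_num_tokens / context_len * 512), 2048)  # -> 12500 slots

    # Each slot is a row of (context_len + 8) token indices, so the request table
    # roughly doubles, and small configurations are floored at 2048 slots.
    print(old_size, new_size)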
@@ -44,7 +44,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

-    def __init__(self, size, dtype, head_num, head_dim, layer_num):
+    def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
@@ -63,16 +63,16 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id):
+    def get_key_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 0]

-    def get_value_buffer(self, layer_id):
+    def get_value_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 1]

     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         buffer_len = len(self.prefetch_buffer)
         if need_size <= buffer_len:
             select_index = self.prefetch_buffer[:need_size]
......
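The annotated accessors imply one KV tensor per layer with keys at index 0 and values at index 1 of the second dimension, plus one extra slot for dummy writes from padded tokens (per the comment in `__init__`). A small sketch of that implied layout, with made-up dimensions and not the actual class:

    import torch

    # Layout implied by get_key_buffer()/get_value_buffer(): one tensor per layer,
    # shaped [size + 1, 2, head_num, head_dim]; [:, 0] holds keys, [:, 1] holds values.
    # The extra slot absorbs dummy writes from padded tokens. Dimensions are made up.
    layer_num, size, head_num, head_dim = 4, 1024, 8, 128
    kv_data = [
        torch.empty(size + 1, 2, head_num, head_dim, dtype=torch.float16)
        for _ in range(layer_num)
    ]

    key_buffer_0 = kv_data[0][:, 0]    # what get_key_buffer(0) would return
    value_buffer_0 = kv_data[0][:, 1]  # what get_value_buffer(0) would return
    assert key_buffer_0.shape == (size + 1, head_num, head_dim)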