Unverified Commit 2ec39ab7 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Chunked prefill support (#797)

parent 8f6274c8
......@@ -38,24 +38,24 @@ class ScheduleHeuristic:
self.max_total_num_tokens = max_total_num_tokens
self.tree_cache = tree_cache
def get_priority_queue(self, forward_queue):
def get_priority_queue(self, waiting_queue):
if self.schedule_heuristic == "lpm":
# longest prefix match
forward_queue.sort(key=lambda x: -len(x.prefix_indices))
return forward_queue
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
return waiting_queue
elif self.schedule_heuristic == "fcfs":
# first come first serve
return forward_queue
return waiting_queue
elif self.schedule_heuristic == "lof":
# longest output first
forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
return forward_queue
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
return waiting_queue
elif self.schedule_heuristic == "random":
random.shuffle(forward_queue)
return forward_queue
random.shuffle(waiting_queue)
return waiting_queue
elif self.schedule_heuristic == "dfs-weight":
last_node_to_reqs = defaultdict(list)
for req in forward_queue:
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
node_to_weight = defaultdict(int)
......@@ -67,7 +67,7 @@ class ScheduleHeuristic:
self.get_dfs_priority(
self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
)
assert len(q) == len(forward_queue)
assert len(q) == len(waiting_queue)
return q
else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
......
......@@ -77,6 +77,10 @@ class ModelTpServer:
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
# Chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
# Init model and tokenizer
self.model_config = ModelConfig(
server_args.model_path,
......@@ -157,7 +161,7 @@ class ModelTpServer:
self.token_to_kv_pool = self.model_runner.token_to_kv_pool
# Init running status
self.forward_queue: List[Req] = []
self.waiting_queue: List[Req] = []
self.running_batch: Batch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
......@@ -220,6 +224,7 @@ class ModelTpServer:
# Run a new prefill batch
self.forward_prefill_batch(new_batch)
self.cache_filled_batch(new_batch)
self.filter_out_inflight(new_batch)
if not new_batch.is_empty():
if self.running_batch is None:
......@@ -261,7 +266,7 @@ class ModelTpServer:
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
f"gen throughput (token/s): {throughput:.2f}, "
f"#queue-req: {len(self.forward_queue)}"
f"#queue-req: {len(self.waiting_queue)}"
)
def check_memory(self):
......@@ -328,9 +333,10 @@ class ModelTpServer:
),
self.max_req_input_len - 1 - len(req.origin_input_ids),
)
self.forward_queue.append(req)
self.waiting_queue.append(req)
def get_new_prefill_batch(self) -> Optional[Batch]:
# TODO(lsyin): organize this function
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
......@@ -338,7 +344,7 @@ class ModelTpServer:
return
# Compute matched prefix length
for req in self.forward_queue:
for req in self.waiting_queue:
req.input_ids = req.origin_input_ids + req.output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
......@@ -348,7 +354,7 @@ class ModelTpServer:
req.last_node = last_node
# Get priority queue
self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
# Add requests if there is available space
can_run_list = []
......@@ -367,7 +373,33 @@ class ModelTpServer:
]
)
for req in self.forward_queue:
# Handle the current inflight request
take_inflight = 0
if self.current_inflight_req:
take_inflight = 1
r = self.current_inflight_req
r.input_ids = r.origin_input_ids + r.output_ids
truncated = (
len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
)
r.extend_input_len = min(
len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
)
r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
can_run_list.append(r)
if not truncated:
# Finish inflight
self.current_inflight_req = None
new_batch_total_tokens += (
r.extend_input_len + r.sampling_params.max_new_tokens
)
new_batch_input_tokens += r.extend_input_len
else:
new_batch_total_tokens += r.extend_input_len
new_batch_input_tokens += r.extend_input_len
for req in self.waiting_queue:
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
......@@ -409,11 +441,36 @@ class ModelTpServer:
break
else:
# Add this request to the running batch
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.sampling_params.max_new_tokens
)
new_batch_input_tokens += req.extend_input_len
if (
new_batch_input_tokens + req.extend_input_len
<= self.chunked_prefill_size
or (
req.return_logprob and req.normalized_prompt_logprob is None
)
):
can_run_list.append(req)
new_batch_total_tokens += (
req.extend_input_len + req.sampling_params.max_new_tokens
)
new_batch_input_tokens += req.extend_input_len
else:
trunc_len = self.chunked_prefill_size - new_batch_input_tokens
if trunc_len <= 0:
# Undo locking
delta = self.tree_cache.dec_lock_ref(req.last_node)
available_size += delta
break
req.extend_input_len = trunc_len
req.input_ids = req.input_ids[
: len(req.prefix_indices) + req.extend_input_len
]
can_run_list.append(req)
self.current_inflight_req = req
new_batch_input_tokens += req.extend_input_len
new_batch_total_tokens += req.extend_input_len
break
else:
break
......@@ -440,7 +497,7 @@ class ModelTpServer:
f"#cached-token: {hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
)
# Return the new batch
......@@ -450,7 +507,7 @@ class ModelTpServer:
self.token_to_kv_pool,
self.tree_cache,
)
self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
return new_batch
def forward_prefill_batch(self, batch: Batch):
......@@ -482,9 +539,10 @@ class ModelTpServer:
# Check finish conditions
pt = 0
for i, req in enumerate(batch.reqs):
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req is not self.current_inflight_req:
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_ids[i])
req.check_finished()
if req.return_logprob:
self.add_logprob_return_values(i, req, pt, next_token_ids, output)
......@@ -545,7 +603,7 @@ class ModelTpServer:
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
del_in_memory_pool=False,
......@@ -553,6 +611,10 @@ class ModelTpServer:
)
req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
if req is self.current_inflight_req:
# inflight request would get a new req idx
self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
def forward_decode_batch(self, batch: Batch):
# Check if decode out of memory
if not batch.check_decode_mem():
......@@ -566,7 +628,7 @@ class ModelTpServer:
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.forward_queue.extend(retracted_reqs)
self.waiting_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_decay,
......@@ -576,7 +638,7 @@ class ModelTpServer:
if not self.disable_regex_jump_forward:
# Check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
self.forward_queue.extend(jump_forward_reqs)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
return
......@@ -711,8 +773,18 @@ class ModelTpServer:
else:
batch.reqs = []
def filter_out_inflight(self, batch: Batch):
# TODO(lsyin): reduce the overhead, make a special version for this
if self.current_inflight_req is None:
return
unfinished_indices = list(range(len(batch.reqs)))
unfinished_indices.remove(batch.reqs.index(self.current_inflight_req))
batch.filter_batch(unfinished_indices)
def flush_cache(self):
if len(self.forward_queue) == 0 and (
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.tree_cache.reset()
......@@ -725,20 +797,20 @@ class ModelTpServer:
else:
warnings.warn(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.forward_queue)}, "
f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
def abort_request(self, recv_req):
# Delete requests in the waiting queue
to_del = None
for i, req in enumerate(self.forward_queue):
for i, req in enumerate(self.waiting_queue):
if req.rid == recv_req.rid:
to_del = i
break
if to_del is not None:
del self.forward_queue[to_del]
del self.waiting_queue[to_del]
# Delete requests in the running batch
if self.running_batch:
......
......@@ -45,7 +45,7 @@ class ReqToTokenPool:
return select_index
def free(self, free_index: int):
def free(self, free_index):
self.mem_state[free_index] = True
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
......
......@@ -175,6 +175,39 @@ def _set_torch_compile_config():
torch._dynamo.config.accumulated_cache_size_limit = 256
def set_envs_and_config(server_args: ServerArgs):
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# Set ulimit
set_ulimit()
# Enable show time cost for debugging
if server_args.show_time_cost:
enable_show_time_cost()
# Disable disk cache
if server_args.disable_disk_cache:
disable_cache()
# Fix triton bugs
if server_args.tp_size * server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()
# Set torch compile config
if server_args.enable_torch_compile:
_set_torch_compile_config()
# Set global chat template
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
def launch_server(
server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
......@@ -190,16 +223,6 @@ def launch_server(
format="%(message)s",
)
# Set global environments
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
set_ulimit()
if server_args.show_time_cost:
enable_show_time_cost()
if server_args.disable_disk_cache:
disable_cache()
if not server_args.disable_flashinfer:
assert_pkg_version(
"flashinfer",
......@@ -208,14 +231,8 @@ def launch_server(
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)
if server_args.tp_size * server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
if server_args.enable_torch_compile:
_set_torch_compile_config()
set_envs_and_config(server_args)
# Allocate ports
server_args.port, server_args.additional_ports = allocate_init_ports(
......
......@@ -65,6 +65,9 @@ class ServerArgs:
dp_size: int = 1
load_balance_method: str = "round_robin"
# Chunked Prefill
chunked_prefill_size: Optional[int] = None
# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
......@@ -83,6 +86,8 @@ class ServerArgs:
node_rank: Optional[int] = None
def __post_init__(self):
if self.chunked_prefill_size is None:
self.chunked_prefill_size = int(10**9)
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
if self.mem_fraction_static is None:
......@@ -223,7 +228,7 @@ class ServerArgs:
parser.add_argument(
"--max-num-reqs",
type=int,
default=None,
default=ServerArgs.max_num_reqs,
help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
)
parser.add_argument(
......@@ -311,10 +316,18 @@ class ServerArgs:
help="The nccl init address of multi-node server.",
)
parser.add_argument(
"--nnodes", type=int, default=1, help="The number of nodes."
"--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Chunked prefill
parser.add_argument(
"--chunked-prefill-size",
type=int,
default=ServerArgs.chunked_prefill_size,
help="The size of the chunked prefill.",
)
# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
......@@ -393,6 +406,10 @@ class ServerArgs:
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
assert not (
self.chunked_prefill_size is not None and self.disable_radix_cache
), "chunked prefill is not supported with radix cache disabled currently"
@dataclasses.dataclass
class PortArgs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment