Unverified commit 98111fbe authored by Ying Sheng, committed by GitHub

Revert "Chunked prefill support" (#799)

parent 2ec39ab7
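For context, the reverted feature split the prefill of a long prompt into fixed-size chunks bounded by `chunked_prefill_size`, carrying the unfinished request across scheduler iterations as `current_inflight_req`. Below is a minimal sketch of that idea only; the helper name and simplified arguments are assumptions, not the reverted implementation.

```python
# Minimal sketch of the chunked-prefill idea that this commit reverts.
# Only the quantities mirror names visible in the diff (prefix_indices,
# input_ids, chunked_prefill_size); the helper itself is illustrative.
from typing import Tuple


def plan_prefill_chunk(input_len: int, prefix_len: int, chunk_size: int) -> Tuple[int, bool]:
    """Return (extend_input_len, still_inflight) for one scheduler iteration.

    A request whose uncached suffix exceeds chunk_size is prefilled in
    pieces: each iteration extends at most chunk_size tokens, and the
    request stays "inflight" until the whole prompt has been prefilled.
    """
    remaining = input_len - prefix_len
    extend_input_len = min(remaining, chunk_size)
    still_inflight = remaining > chunk_size
    return extend_input_len, still_inflight


# Example: a 10k-token prompt with 2k tokens already cached and a 4096-token
# chunk budget is prefilled over two more iterations (4096 now, 3904 later).
print(plan_prefill_chunk(10_000, 2_000, 4_096))  # -> (4096, True)
```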
@@ -38,24 +38,24 @@ class ScheduleHeuristic:
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def get_priority_queue(self, forward_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
-            waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
+            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
+            return forward_queue
         elif self.schedule_heuristic == "fcfs":
             # first come first serve
-            return waiting_queue
+            return forward_queue
         elif self.schedule_heuristic == "lof":
             # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
-            random.shuffle(waiting_queue)
-            return waiting_queue
+            random.shuffle(forward_queue)
+            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
-            for req in waiting_queue:
+            for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
 
             node_to_weight = defaultdict(int)
@@ -67,7 +67,7 @@ class ScheduleHeuristic:
             self.get_dfs_priority(
                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(q) == len(waiting_queue)
+            assert len(q) == len(forward_queue)
             return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
...
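The hunk above only renames the queue argument back to `forward_queue`; for readers unfamiliar with the heuristics it orders, here is a hedged standalone sketch of the same ordering logic (the `dfs-weight` case is omitted). `SimpleReq` and its fields are illustrative stand-ins for the server's `Req` objects.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class SimpleReq:
    prefix_len: int      # stands in for len(req.prefix_indices)
    max_new_tokens: int  # stands in for req.sampling_params.max_new_tokens


def order_queue(queue: List[SimpleReq], heuristic: str) -> List[SimpleReq]:
    if heuristic == "lpm":       # longest prefix match first (favors cache reuse)
        queue.sort(key=lambda r: -r.prefix_len)
    elif heuristic == "lof":     # longest output first
        queue.sort(key=lambda r: -r.max_new_tokens)
    elif heuristic == "random":
        random.shuffle(queue)
    elif heuristic != "fcfs":    # "fcfs" keeps arrival order unchanged
        raise ValueError(f"Unknown schedule_heuristic: {heuristic}")
    return queue


# Example: with "lpm", the request with the longest cached prefix runs first.
print(order_queue([SimpleReq(2, 64), SimpleReq(8, 16)], "lpm"))
```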
@@ -77,10 +77,6 @@ class ModelTpServer:
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -161,7 +157,7 @@ class ModelTpServer:
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
-        self.waiting_queue: List[Req] = []
+        self.forward_queue: List[Req] = []
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
@@ -224,7 +220,6 @@ class ModelTpServer:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
             self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -266,7 +261,7 @@ class ModelTpServer:
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"gen throughput (token/s): {throughput:.2f}, "
-                f"#queue-req: {len(self.waiting_queue)}"
+                f"#queue-req: {len(self.forward_queue)}"
             )
 
     def check_memory(self):
@@ -333,10 +328,9 @@ class ModelTpServer:
            ),
            self.max_req_input_len - 1 - len(req.origin_input_ids),
        )
-        self.waiting_queue.append(req)
+        self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
@@ -344,7 +338,7 @@ class ModelTpServer:
             return
 
         # Compute matched prefix length
-        for req in self.waiting_queue:
+        for req in self.forward_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
@@ -354,7 +348,7 @@ class ModelTpServer:
             req.last_node = last_node
 
         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
 
         # Add requests if there is available space
         can_run_list = []
@@ -373,33 +367,7 @@ class ModelTpServer:
                ]
            )
 
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
-            )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
-            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
-
-        for req in self.waiting_queue:
+        for req in self.forward_queue:
            if req.return_logprob and req.normalized_prompt_logprob is None:
                # Need at least two tokens to compute normalized logprob
                if req.extend_input_len < 2:
@@ -441,36 +409,11 @@ class ModelTpServer:
                    break
                else:
                    # Add this request to the running batch
-                    if (
-                        new_batch_input_tokens + req.extend_input_len
-                        <= self.chunked_prefill_size
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
+                    can_run_list.append(req)
+                    new_batch_total_tokens += (
+                        req.extend_input_len + req.sampling_params.max_new_tokens
+                    )
+                    new_batch_input_tokens += req.extend_input_len
            else:
                break
@@ -497,7 +440,7 @@ class ModelTpServer:
                f"#cached-token: {hit_tokens}, "
                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
            )
 
        # Return the new batch
@@ -507,7 +450,7 @@ class ModelTpServer:
            self.token_to_kv_pool,
            self.tree_cache,
        )
-        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
+        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
        return new_batch
 
    def forward_prefill_batch(self, batch: Batch):
@@ -539,10 +482,9 @@ class ModelTpServer:
        # Check finish conditions
        pt = 0
        for i, req in enumerate(batch.reqs):
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+            req.completion_tokens_wo_jump_forward += 1
+            req.output_ids.append(next_token_ids[i])
+            req.check_finished()
 
            if req.return_logprob:
                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -603,7 +545,7 @@ class ModelTpServer:
        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids),
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                last_uncached_pos=len(req.prefix_indices),
                req_pool_idx=req_pool_indices_cpu[i],
                del_in_memory_pool=False,
@@ -611,10 +553,6 @@ class ModelTpServer:
            )
            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
 
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
-
    def forward_decode_batch(self, batch: Batch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
@@ -628,7 +566,7 @@ class ModelTpServer:
                f"#retracted_reqs: {len(retracted_reqs)}, "
                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
            )
-            self.waiting_queue.extend(retracted_reqs)
+            self.forward_queue.extend(retracted_reqs)
        else:
            self.new_token_ratio = max(
                self.new_token_ratio - self.new_token_ratio_decay,
@@ -638,7 +576,7 @@ class ModelTpServer:
        if not self.disable_regex_jump_forward:
            # Check for jump-forward
            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-            self.waiting_queue.extend(jump_forward_reqs)
+            self.forward_queue.extend(jump_forward_reqs)
 
            if batch.is_empty():
                return
@@ -773,18 +711,8 @@ class ModelTpServer:
        else:
            batch.reqs = []
 
-    def filter_out_inflight(self, batch: Batch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        unfinished_indices = list(range(len(batch.reqs)))
-        unfinished_indices.remove(batch.reqs.index(self.current_inflight_req))
-        batch.filter_batch(unfinished_indices)
-
    def flush_cache(self):
-        if len(self.waiting_queue) == 0 and (
+        if len(self.forward_queue) == 0 and (
            self.running_batch is None or len(self.running_batch.reqs) == 0
        ):
            self.tree_cache.reset()
@@ -797,20 +725,20 @@ class ModelTpServer:
        else:
            warnings.warn(
                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.waiting_queue)}, "
+                f"#queue-req: {len(self.forward_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
 
    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
        to_del = None
-        for i, req in enumerate(self.waiting_queue):
+        for i, req in enumerate(self.forward_queue):
            if req.rid == recv_req.rid:
                to_del = i
                break
 
        if to_del is not None:
-            del self.waiting_queue[to_del]
+            del self.forward_queue[to_del]
 
        # Delete requests in the running batch
        if self.running_batch:
...
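The largest removal above is the chunk-budget handling in `get_new_prefill_batch`. As a hedged, self-contained illustration of the arithmetic that block performed (the field and helper names below are simplified stand-ins, not the reverted code):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PendingReq:
    input_ids: List[int]
    prefix_len: int              # tokens already cached (len(prefix_indices))
    extend_input_len: int = 0    # tokens to prefill in this iteration


def truncate_to_chunk_budget(
    req: PendingReq, chunk_size: int, batch_input_tokens: int
) -> Optional[PendingReq]:
    """Shrink req so the batch's prefill tokens stay within chunk_size.

    Returns the request that becomes "inflight" (resumed next iteration),
    or None when the current chunk has no budget left for it.
    """
    trunc_len = chunk_size - batch_input_tokens
    if trunc_len <= 0:
        return None  # no room in this chunk; leave the request queued
    req.extend_input_len = trunc_len
    req.input_ids = req.input_ids[: req.prefix_len + req.extend_input_len]
    return req


# Example: 100 prompt tokens, 10 cached, 48 batch tokens already used with a
# 64-token chunk budget -> only 16 more tokens are prefilled this iteration.
req = truncate_to_chunk_budget(PendingReq(list(range(100)), 10), 64, 48)
print(req.extend_input_len, len(req.input_ids))  # -> 16 26
```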
@@ -45,7 +45,7 @@ class ReqToTokenPool:
 
         return select_index
 
-    def free(self, free_index):
+    def free(self, free_index: int):
         self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
             self.can_use_mem_size += 1
...
@@ -175,39 +175,6 @@ def _set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 256
 
 
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -223,6 +190,16 @@ def launch_server(
         format="%(message)s",
     )
 
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    set_ulimit()
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+    if server_args.disable_disk_cache:
+        disable_cache()
 
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
@@ -231,8 +208,14 @@ def launch_server(
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-
-    set_envs_and_config(server_args)
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
...
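The hunk above moves environment setup back inline into `launch_server` and drops the `set_envs_and_config` helper. A minimal standalone sketch of that setup, using only the variables shown in the diff (the helper name here is illustrative, and the purpose comments are my own reading of these flags):

```python
import os


def configure_runtime_env() -> None:  # illustrative helper name, not the server's API
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"             # silence TensorFlow C++ logging
    os.environ["NCCL_CUMEM_ENABLE"] = "0"                # disable NCCL's cuMem allocator path
    os.environ["NCCL_NVLS_ENABLE"] = "0"                 # disable NCCL NVLink SHARP (NVLS)
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"  # avoid recordStream in torch NCCL ops


configure_runtime_env()
```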
@@ -65,9 +65,6 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
-    # Chunked Prefill
-    chunked_prefill_size: Optional[int] = None
-
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -86,8 +83,6 @@ class ServerArgs:
     node_rank: Optional[int] = None
 
     def __post_init__(self):
-        if self.chunked_prefill_size is None:
-            self.chunked_prefill_size = int(10**9)
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
@@ -228,7 +223,7 @@ class ServerArgs:
         parser.add_argument(
             "--max-num-reqs",
             type=int,
-            default=ServerArgs.max_num_reqs,
+            default=None,
             help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
         )
         parser.add_argument(
@@ -316,18 +311,10 @@ class ServerArgs:
             help="The nccl init address of multi-node server.",
         )
         parser.add_argument(
-            "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
+            "--nnodes", type=int, default=1, help="The number of nodes."
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
-        # Chunked prefill
-        parser.add_argument(
-            "--chunked-prefill-size",
-            type=int,
-            default=ServerArgs.chunked_prefill_size,
-            help="The size of the chunked prefill.",
-        )
-
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -406,10 +393,6 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
 
-        assert not (
-            self.chunked_prefill_size is not None and self.disable_radix_cache
-        ), "chunked prefill is not supported with radix cache disabled currently"
-
 
 @dataclasses.dataclass
 class PortArgs:
...
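The removed `ServerArgs` lines follow a common pattern: a dataclass field supplies the default, `__post_init__` resolves the `None` sentinel, and the argparse flag takes its default from the dataclass attribute. A hedged sketch of that pattern (`DemoArgs` and `add_cli_args` are illustrative stand-ins, not the real `ServerArgs` API):

```python
import argparse
import dataclasses
from typing import Optional


@dataclasses.dataclass
class DemoArgs:
    chunked_prefill_size: Optional[int] = None

    def __post_init__(self):
        if self.chunked_prefill_size is None:
            # effectively "no chunking" when the flag is not set
            self.chunked_prefill_size = int(10**9)

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> None:
        parser.add_argument(
            "--chunked-prefill-size",
            type=int,
            default=DemoArgs.chunked_prefill_size,  # reads the dataclass default (None)
            help="The size of the chunked prefill.",
        )


parser = argparse.ArgumentParser()
DemoArgs.add_cli_args(parser)
args = DemoArgs(**vars(parser.parse_args([])))  # defaults only
print(args.chunked_prefill_size)  # -> 1000000000
```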