Unverified Commit f724f1f1, authored by Liangsheng Yin, committed by GitHub

PrefillAdder abstraction (#968)

parent 6db27f7b
@@ -17,6 +17,9 @@ limitations under the License.
 import random
 from collections import defaultdict
+from contextlib import contextmanager
+
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 
 
 class PolicyScheduler:
@@ -83,3 +86,122 @@ class PolicyScheduler:
         for child in childs:
             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])

class PrefillAdder:
    """Track the remaining token budgets while packing requests into a prefill batch."""

def __init__(
self,
tree_cache,
rem_total_tokens,
rem_input_tokens,
rem_chunk_tokens,
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens
self.rem_chunk_tokens = rem_chunk_tokens
self.can_run_list = []
self.new_inflight_req = None
self.log_hit_tokens = 0
self.log_input_tokens = 0

    def no_remaining_tokens(self):
        # True once any of the remaining budgets is exhausted.
return (
self.rem_total_tokens <= 0
or self.rem_input_tokens <= 0
or (
self.rem_chunk_tokens <= 0
if self.rem_chunk_tokens is not None
else False
)
)

    def remove_running_tokens(
        self, running_batch: ScheduleBatch, new_token_ratio: float
    ):
        # Reserve decode budget for requests that are already running,
        # scaled by new_token_ratio.
self.rem_total_tokens -= sum(
[
(r.sampling_params.max_new_tokens - len(r.output_ids)) * new_token_ratio
for r in running_batch.reqs
]
)

    def _prefill_one_req(
        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
    ):
        # Charge this request's extend tokens (plus its decode allowance)
        # against the remaining budgets and record the logging counters.
self.rem_total_tokens -= extend_input_len + max_new_tokens
self.rem_input_tokens -= extend_input_len
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= extend_input_len
self.log_hit_tokens += prefix_len
self.log_input_tokens += extend_input_len

    def add_inflight_req(self, req: Req):
        # Continue the chunked-prefill request carried over from the previous step.
req.input_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = len(req.input_ids) - len(req.prefix_indices)
truncated = req.extend_input_len > self.rem_chunk_tokens
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.input_ids = req.input_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)
self._prefill_one_req(
len(req.prefix_indices),
req.extend_input_len,
req.sampling_params.max_new_tokens if not truncated else 0,
)
        # Hand the request back if its chunked prefill is not finished,
        # so the caller keeps it as the inflight request.
        return req if truncated else None

    @contextmanager
    def _lock_node(self, last_node):
        # Temporarily lock the matched prefix node so it cannot be evicted while
        # this request is being considered. inc_lock_ref/dec_lock_ref return the
        # change in evictable size, which is folded into the total-token budget.
try:
delta = self.tree_cache.inc_lock_ref(last_node)
self.rem_total_tokens += delta
yield None
finally:
delta = self.tree_cache.dec_lock_ref(last_node)
self.rem_total_tokens += delta

    def add_one_req(self, req: Req):
        # Try to schedule one waiting request; returns False if it does not fit.
total_tokens = req.extend_input_len + req.sampling_params.max_new_tokens
input_tokens = req.extend_input_len
prefix_len = len(req.prefix_indices)
if total_tokens >= self.rem_total_tokens:
return False
if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
return False
with self._lock_node(req.last_node):
if total_tokens > self.rem_total_tokens:
return False
if (
self.rem_chunk_tokens is None
or input_tokens <= self.rem_chunk_tokens
or (req.return_logprob and req.normalized_prompt_logprob is None)
):
# Non-chunked prefill
self.can_run_list.append(req)
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(
prefix_len, input_tokens, req.sampling_params.max_new_tokens
)
else:
# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return False
req.extend_input_len = trunc_len
req.input_ids = req.input_ids[: len(req.prefix_indices) + trunc_len]
self.can_run_list.append(req)
self.new_inflight_req = req
self.tree_cache.inc_lock_ref(req.last_node)
self._prefill_one_req(prefix_len, trunc_len, 0)
return True
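
To make the new abstraction easier to follow, here is a minimal, self-contained usage sketch of `PrefillAdder` outside the server. The `_StubParams`, `_StubReq`, and `_StubTreeCache` classes and all budget numbers are hypothetical placeholders invented for this example (they are not SGLang APIs); only the `PrefillAdder` constructor and method calls mirror the diff in this commit.

```python
# Hedged usage sketch of PrefillAdder with stubbed-out dependencies.
# The stubs are placeholders for this example only; in SGLang the real Req
# objects and the radix tree cache play these roles.
from dataclasses import dataclass, field
from typing import List, Optional

from sglang.srt.managers.policy_scheduler import PrefillAdder  # class added above


@dataclass
class _StubParams:
    max_new_tokens: int = 16


@dataclass
class _StubReq:
    origin_input_ids: List[int]
    output_ids: List[int] = field(default_factory=list)
    prefix_indices: List[int] = field(default_factory=list)
    input_ids: Optional[List[int]] = None
    extend_input_len: int = 0
    last_node: object = None
    return_logprob: bool = False
    normalized_prompt_logprob: Optional[float] = None
    sampling_params: _StubParams = field(default_factory=_StubParams)


class _StubTreeCache:
    # inc/dec_lock_ref return the change in evictable size, as _lock_node assumes.
    def inc_lock_ref(self, node):
        return 0

    def dec_lock_ref(self, node):
        return 0


# Three waiting requests with 8, 32, and 128 uncached input tokens.
waiting_queue = []
for n in (8, 32, 128):
    req = _StubReq(origin_input_ids=list(range(n)))
    req.input_ids = req.origin_input_ids
    req.extend_input_len = n
    waiting_queue.append(req)

adder = PrefillAdder(
    tree_cache=_StubTreeCache(),
    rem_total_tokens=4096,  # KV budget: available + evictable tokens
    rem_input_tokens=2048,  # max_prefill_tokens
    rem_chunk_tokens=64,    # chunked_prefill_size (None disables chunking)
)

for req in waiting_queue:
    if not adder.add_one_req(req) or adder.no_remaining_tokens():
        break

# The 128-token request is truncated to the remaining 24-token chunk budget and
# becomes the new inflight request; the scheduler would resume it next round.
print(len(adder.can_run_list), adder.log_input_tokens, adder.new_inflight_req is req)
```
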
@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
@@ -377,151 +377,57 @@ class ModelTpServer:
         # Get priority queue
         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
 
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
-
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
-            )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
-            )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
+        )
+
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
+            )
 
         for req in self.waiting_queue:
-            if req.return_logprob and req.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-            if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
-            ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
-                break
-            if running_bs + len(can_run_list) >= self.max_running_requests:
-                break
+            res = adder.add_one_req(req)
+            if (
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
+            ):
+                break
+
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
             self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
+                adder.log_input_tokens + adder.log_hit_tokens
            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
             tree_cache_hit_rate = (
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {new_batch_input_tokens}, "
-                f"#cached-token: {hit_tokens}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
...
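
As a sanity check on the bookkeeping, the following tiny example (with made-up numbers, not taken from SGLang) verifies that the decode budget reserved by `adder.remove_running_tokens` matches what the removed inline `available_size -= sum(...)` computed:

```python
# Two hypothetical running requests that may still emit 128 - 28 = 100 and
# 32 - 12 = 20 tokens, with new_token_ratio = 0.5. Both the removed inline code
# and PrefillAdder.remove_running_tokens subtract the same reserved budget.
running = [(128, 28), (32, 12)]  # (max_new_tokens, tokens already generated)
new_token_ratio = 0.5

reserved = sum((max_new - done) * new_token_ratio for max_new, done in running)
assert reserved == (100 + 20) * new_token_ratio == 60.0
```
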
@@ -130,7 +130,7 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flash_infer()
+        self.init_flashinfer()
 
         # Capture cuda graphs
         self.init_cuda_graphs()
@@ -287,7 +287,7 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flash_infer(self):
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
...
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -46,8 +46,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,14 +391,6 @@
         )
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
...