Unverified Commit f86c1e61 authored by Lianmin Zheng, committed by GitHub

Move scheduler code from tp_worker.py to scheduler.py (#1538)

parent acaffd23
@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len
 
         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
 
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
......
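These two hunks switch the benchmark request preparation to pass sampling_params through the Req constructor instead of assigning it after construction. Below is a minimal sketch of the new calling pattern; it is not part of the patch, the import paths are assumed from the sglang source tree at the time of this commit, and the token ids and SamplingParams arguments are purely illustrative.

# Sketch only: mirrors the new construction pattern from the hunks above.
# Import paths are assumptions; they are not shown in this diff.
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams

sampling_params = SamplingParams(temperature=0.0, max_new_tokens=16)  # illustrative values

input_ids = [[1, 2, 3, 4], [5, 6, 7, 8]]  # fake token ids for illustration
reqs = []
for i, ids in enumerate(input_ids):
    # sampling_params is now a constructor argument rather than a post-hoc attribute.
    req = Req(
        rid=i,
        origin_input_text="",
        origin_input_ids=list(ids),
        sampling_params=sampling_params,
    )
    req.prefix_indices = []
    req.fill_ids = req.origin_input_ids
    req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
    reqs.append(req)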
@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
 
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # Whether it is a single request or a batch request
+    is_single: bool = True
 
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-
 
 @dataclass
 class BatchStrOut:
......
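Removing BatchTokenIDOut.__post_init__ means meta_info is no longer deep-copied at construction, so the object now aliases the producer's list of dicts instead of holding a private snapshot. The stdlib-only sketch below illustrates that aliasing difference; the Out* classes and the dict key are stand-ins, not the real BatchTokenIDOut.

import copy
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class OutShared:
    # Mirrors the new behavior: the list passed in is stored as-is.
    meta_info: List[Dict]


@dataclass
class OutCopied:
    # Mirrors the removed __post_init__: a deep copy decouples the stored list.
    meta_info: List[Dict]

    def __post_init__(self):
        self.meta_info = copy.deepcopy(self.meta_info)


meta = [{"completion_tokens": 1}]
shared, copied = OutShared(meta), OutCopied(meta)

meta[0]["completion_tokens"] = 99                 # producer mutates in place afterwards
print(shared.meta_info[0]["completion_tokens"])   # 99: the mutation is visible
print(copied.meta_info[0]["completion_tokens"])   # 1: the snapshot stays isolated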
@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -152,6 +154,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.sampling_params = sampling_params
+
         self.lora_path = lora_path
 
         # Memory info
@@ -160,6 +164,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -187,10 +192,6 @@ class Req:
         self.extend_input_len = 0
         self.last_node = None
 
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
......
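With the hunks above, sampling_params becomes a required Req constructor parameter and stream is initialized inside __init__ instead of in the removed "Sampling parameters" block. A short hedged sketch of what that means for callers follows; the module paths are assumed and the argument values are illustrative.

from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams

sp = SamplingParams()  # default sampling parameters, for illustration

# New-style construction: sampling_params must be supplied up front.
req = Req(rid="r0", origin_input_text="hi", origin_input_ids=[1, 2, 3], sampling_params=sp)
assert req.sampling_params is sp
assert req.stream is False  # now set during __init__ rather than by a later assignment

# Old-style construction (without sampling_params) no longer works:
try:
    Req(rid="r1", origin_input_text="hi", origin_input_ids=[1, 2, 3])
except TypeError as exc:
    print("sampling_params is now required:", exc)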
This diff is collapsed.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Request policy scheduler"""
+"""Request scheduler policy"""
 
 import os
 import random
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
......
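Since this is a pure rename plus a docstring fix, the only migration needed at call sites is the class name; the constructor signature shown above is unchanged. A hedged call-site sketch follows; the module path is not visible in this diff and tree_cache is a placeholder for whichever BasePrefixCache instance the caller already has.

# Assumed import path; the file name is not shown in this diff.
from sglang.srt.managers.policy_scheduler import SchedulerPolicy

tree_cache = ...  # placeholder: an existing BasePrefixCache (e.g. a radix cache)

# Before this commit: PolicyScheduler("lpm", tree_cache)
policy = SchedulerPolicy("lpm", tree_cache)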
This diff is collapsed.
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )
 
     def alloc(self, need_size: int) -> List[int]:
......
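Parameterizing the device means the request-to-token map no longer has to live on CUDA, which also makes the pool constructible on machines without a GPU. A small sketch of the updated constructor follows; the import path matches the one used elsewhere in this commit, while the sizes and the "cpu" device are illustrative (ModelRunner itself passes "cuda").

from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

# The new `device` parameter replaces the hard-coded "cuda"; "cpu" is just for illustration.
pool = ReqToTokenPool(size=8, max_context_len=2048, device="cpu")
assert pool.req_to_token.shape == (8, 2048)
assert pool.req_to_token.device.type == "cpu"

slots = pool.alloc(2)  # reserve slots for two requests, same API as before
print(slots)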
@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )
 
+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +95,13 @@
             logger.info("MLA optimization is tunred on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +112,6 @@
             }
         )
 
-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -400,8 +400,7 @@
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
......
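These hunks group the model-specific adjustments so that server_args is modified before its values are copied into global_server_args_dict, and they pass device="cuda" explicitly to the updated ReqToTokenPool constructor. The stdlib-only sketch below illustrates the general adjust-then-snapshot ordering, without claiming it is the motivation for this particular move; the FakeServerArgs fields are made up and are not the real ServerArgs or global_server_args_dict contents.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeServerArgs:
    # Made-up fields for illustration only.
    chunked_prefill_size: Optional[int] = 8192
    mem_fraction_static: float = 0.88


def snapshot(args: FakeServerArgs) -> dict:
    # Stand-in for copying selected fields into a module-level dict,
    # the way global_server_args_dict.update(...) snapshots server_args.
    return {
        "chunked_prefill_size": args.chunked_prefill_size,
        "mem_fraction_static": args.mem_fraction_static,
    }


args = FakeServerArgs()

# Adjust first, then snapshot: the snapshot reflects the adjusted values.
args.chunked_prefill_size = None
args.mem_fraction_static *= 0.95
print(snapshot(args))  # chunked_prefill_size is None, mem_fraction_static is scaled down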