Unverified commit f86c1e61 authored by Lianmin Zheng, committed by GitHub

Move scheduler code from tp_worker.py to scheduler.py (#1538)

parent acaffd23
@@ -167,9 +167,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len
         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +203,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
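Both benchmark helpers now pass sampling_params to the Req constructor instead of assigning req.sampling_params afterwards. A minimal sketch of the new call pattern, with illustrative values; the Req module path is an assumption (this diff does not name the file), while the SamplingParams path matches the import added further down:

from sglang.srt.managers.schedule_batch import Req  # module path assumed
from sglang.srt.sampling.sampling_params import SamplingParams

sampling_params = SamplingParams(temperature=0.0, max_new_tokens=16)  # illustrative values
req = Req(
    rid=0,
    origin_input_text="",
    origin_input_ids=[1, 2, 3],
    sampling_params=sampling_params,  # now passed at construction time
)
req.prefix_indices = []
req.fill_ids = req.origin_input_ids
req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)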
@@ -18,7 +18,6 @@ The definition of objects transferred between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -53,12 +52,12 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    is_single: bool = True
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # Whether it is a single request or a batch request
+    is_single: bool = True

     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
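The is_single flag now sits next to an explanatory comment at the end of the field list. A hedged usage sketch, assuming the module path and the sampling_params field (neither is shown in this hunk):

from sglang.srt.managers.io_struct import GenerateReqInput  # module path assumed

req = GenerateReqInput(text="Hello", sampling_params={"max_new_tokens": 8})  # fields beyond this hunk are assumptions
req.post_init()  # rejects requests that set both or neither of text / input_ids
print(req.is_single)  # expected True for a plain string, i.e. a single (non-batch) request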
@@ -307,10 +306,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-

 @dataclass
 class BatchStrOut:
@@ -31,6 +31,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -143,6 +144,7 @@ class Req:
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
@@ -152,6 +154,8 @@
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        self.sampling_params = sampling_params
+        self.lora_path = lora_path

         # Memory info
@@ -160,6 +164,7 @@
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -187,10 +192,6 @@
         self.extend_input_len = 0
         self.last_node = None

-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
This diff is collapsed.
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""Request policy scheduler"""
+"""Request scheduler policy"""

 import os
 import random
@@ -32,7 +32,7 @@ from sglang.srt.mem_cache.radix_cache import TreeNode
 CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


-class PolicyScheduler:
+class SchedulerPolicy:
     def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight are meaningless when the tree cache is disabled.
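Because the class is renamed, call sites that built a PolicyScheduler must now construct SchedulerPolicy instead. A usage sketch under stated assumptions: the module path is assumed, and the stub below stands in for a real BasePrefixCache, providing only the disable attribute that the visible constructor logic reads.

from types import SimpleNamespace

from sglang.srt.managers.policy_scheduler import SchedulerPolicy  # module path assumed

stub_cache = SimpleNamespace(disable=True)  # hypothetical stand-in for a prefix cache
policy = SchedulerPolicy(policy="lpm", tree_cache=stub_cache)
# With the tree cache disabled, "lpm"/"dfs-weight" are meaningless (per the comment
# above), so the scheduler is expected to fall back to a simpler ordering policy.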
This diff is collapsed.
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""

-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )

     def alloc(self, need_size: int) -> List[int]:
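With device now an explicit constructor argument rather than a hard-coded "cuda", the request-to-token table can be placed on any torch device. A minimal sketch with illustrative sizes; the import path matches the one shown in the hunk further up, and an installed sglang is assumed:

import torch

from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

device = "cuda" if torch.cuda.is_available() else "cpu"  # no longer forced to CUDA
pool = ReqToTokenPool(size=256, max_context_len=4096, device=device)
slots = pool.alloc(2)  # alloc() hands out free request slots, per the class shown above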
@@ -87,6 +87,7 @@ class ModelRunner:
             self.model_config.hf_config.architectures
         )

+        # Model-specific adjustment
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -94,6 +95,13 @@
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -104,14 +112,6 @@
             }
         )

-        # Model-specific adjustment
-        if self.is_multimodal_model:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
@@ -400,8 +400,7 @@
         )
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA