Unverified Commit cdcbde5f authored by Liangsheng Yin, committed by GitHub

Code structure refactor (#807)

parent 21e22b9e
......@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
If OOM happens during decoding, try to decrease `--max-running-requests`.
You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
### (Minor) Tune `--schedule-heuristic`
If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match.
### (Minor) Tune `--schedule-policy`
If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve.
you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve.
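For reference, the same knobs can also be set programmatically. The snippet below is a minimal, unverified sketch that assumes `sglang.Runtime` forwards these options as snake_case keyword arguments to `ServerArgs` (the model path and values are placeholders; check them against your installed version):

```python
import sglang as sgl

# Assumption: Runtime accepts the CLI flags above as keyword arguments.
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
    mem_fraction_static=0.8,      # shrink the KV cache pool if OOM occurs
    max_prefill_tokens=8192,      # bound tokens processed per prefill batch
    max_running_requests=64,      # bound concurrent decoding requests
    schedule_policy="fcfs",       # default "lpm" is better with shared prefixes
)
sgl.set_default_backend(runtime)
```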
# SGL API Components
from sglang.api import (
Runtime,
assistant,
......@@ -22,46 +23,46 @@ from sglang.api import (
video,
)
# Global Configurations
from sglang.global_config import global_config
# SGL Backends
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import LazyImport
from sglang.version import __version__
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
# public APIs management
# SGLang DSL APIs
__all__ = [
"global_config",
"Anthropic",
"LiteLLM",
"OpenAI",
"RuntimeEndpoint",
"VertexAI",
"function",
"Runtime",
"set_default_backend",
"assistant",
"assistant_begin",
"assistant_end",
"flush_cache",
"get_server_args",
"function",
"gen",
"gen_int",
"gen_string",
"get_server_args",
"image",
"video",
"select",
"set_default_backend",
"system",
"system_begin",
"system_end",
"user",
"assistant",
"user_begin",
"user_end",
"assistant_begin",
"assistant_end",
"system_begin",
"system_end",
"video",
]
# Global Configurations
from sglang.global_config import global_config
__all__ += ["global_config"]
from sglang.version import __version__
__all__ += ["__version__"]
# SGL Backends
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import LazyImport
Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
......@@ -37,9 +37,9 @@ import torch
import torch.distributed as dist
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import suppress_other_loggers
......
......@@ -25,7 +25,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
)
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
@dataclasses.dataclass
......
......@@ -22,7 +22,7 @@ from torch import nn
from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
from sglang.srt.layers.token_attention import token_attention_fwd
from sglang.srt.managers.controller.model_runner import (
from sglang.srt.model_executor.model_runner import (
ForwardMode,
InputMetadata,
global_server_args_dict,
......
......@@ -20,7 +20,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.managers.controller.infer_batch import global_server_args_dict
from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
......
......@@ -27,7 +27,7 @@ from enum import Enum, auto
import numpy as np
import zmq
from sglang.srt.managers.controller.manager_single import (
from sglang.srt.managers.controller_single import (
start_controller_process as start_controller_process_single,
)
from sglang.srt.managers.io_struct import (
......
......@@ -22,7 +22,7 @@ from typing import List
import zmq
from sglang.srt.managers.controller.tp_worker import (
from sglang.srt.managers.tp_worker import (
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
......
......@@ -25,8 +25,8 @@ import zmq
import zmq.asyncio
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
......
......@@ -22,7 +22,7 @@ import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from sglang.srt.managers.controller.infer_batch import BaseFinishReason
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
......
......@@ -13,47 +13,47 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""Request scheduler heuristic."""
"""Request policy scheduler"""
import random
from collections import defaultdict
class ScheduleHeuristic:
class PolicyScheduler:
def __init__(
self,
schedule_heuristic,
policy,
max_running_seqs,
max_prefill_num_tokens,
max_total_num_tokens,
tree_cache,
):
if tree_cache.disable and schedule_heuristic == "lpm":
if tree_cache.disable and policy == "lpm":
# LPM is meaningless when the tree cache is disabled.
schedule_heuristic = "fcfs"
policy = "fcfs"
self.schedule_heuristic = schedule_heuristic
self.policy = policy
self.max_running_seqs = max_running_seqs
self.max_prefill_num_tokens = max_prefill_num_tokens
self.max_total_num_tokens = max_total_num_tokens
self.tree_cache = tree_cache
def get_priority_queue(self, waiting_queue):
if self.schedule_heuristic == "lpm":
if self.policy == "lpm":
# longest prefix match
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
return waiting_queue
elif self.schedule_heuristic == "fcfs":
elif self.policy == "fcfs":
# first come first serve
return waiting_queue
elif self.schedule_heuristic == "lof":
elif self.policy == "lof":
# longest output first
waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
return waiting_queue
elif self.schedule_heuristic == "random":
elif self.policy == "random":
random.shuffle(waiting_queue)
return waiting_queue
elif self.schedule_heuristic == "dfs-weight":
elif self.policy == "dfs-weight":
last_node_to_reqs = defaultdict(list)
for req in waiting_queue:
last_node_to_reqs[req.last_node].append(req)
......@@ -70,7 +70,7 @@ class ScheduleHeuristic:
assert len(q) == len(waiting_queue)
return q
else:
raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
raise ValueError(f"Unknown schedule_policy: {self.policy}")
def calc_weight(self, cur_node, node_to_weight):
for child in cur_node.children.values():
......
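To make the renamed `PolicyScheduler` behavior concrete, here is a small self-contained sketch of how the `lpm` and `lof` orderings differ. The `Req` and `SamplingParams` classes are hypothetical stubs carrying only the fields used in the code above, not sglang's real classes:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SamplingParams:  # hypothetical stub
    max_new_tokens: int


@dataclass
class Req:  # hypothetical stub with only the fields the scheduler reads
    rid: str
    prefix_indices: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=lambda: SamplingParams(16))


reqs = [
    Req("a", prefix_indices=[0] * 2, sampling_params=SamplingParams(128)),
    Req("b", prefix_indices=[0] * 30, sampling_params=SamplingParams(8)),
    Req("c", prefix_indices=[0] * 10, sampling_params=SamplingParams(512)),
]

# "lpm" (longest prefix match): most cached prefix tokens first -> b, c, a
lpm_order = sorted(reqs, key=lambda r: -len(r.prefix_indices))
# "lof" (longest output first): largest max_new_tokens first -> c, a, b
lof_order = sorted(reqs, key=lambda r: -r.sampling_params.max_new_tokens)

print([r.rid for r in lpm_order])  # ['b', 'c', 'a']
print([r.rid for r in lof_order])  # ['c', 'a', 'b']
```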
......@@ -28,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
......
......@@ -29,23 +29,23 @@ from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.controller.infer_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
ForwardMode,
Req,
)
from sglang.srt.managers.controller.model_runner import ModelRunner
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
from sglang.srt.managers.io_struct import (
AbortReq,
BatchTokenIDOut,
FlushCacheReq,
TokenizedGenerateReqInput,
)
from sglang.srt.managers.policy_scheduler import PolicyScheduler
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
BaseFinishReason,
Batch,
ForwardMode,
Req,
)
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.model_config import ModelConfig
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_int_token_logit_bias,
......@@ -74,7 +74,7 @@ class ModelTpServer:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.dp_size = server_args.dp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.schedule_policy = server_args.schedule_policy
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
# Chunked prefill
......@@ -150,8 +150,8 @@ class ModelTpServer:
disable=server_args.disable_radix_cache,
)
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.scheduler = ScheduleHeuristic(
self.schedule_heuristic,
self.scheduler = PolicyScheduler(
self.schedule_policy,
self.max_running_requests,
self.max_prefill_tokens,
self.max_total_num_tokens,
......
......@@ -17,7 +17,7 @@ limitations under the License.
Flush the KV cache.
Usage:
python3 -m sglang.srt.flush_cache --url http://localhost:30000
python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
"""
import argparse
......
......@@ -29,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
)
from sglang.srt.managers.controller.infer_batch import (
from sglang.srt.managers.schedule_batch import (
Batch,
ForwardMode,
InputMetadata,
......
......@@ -40,13 +40,13 @@ from vllm.distributed import (
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.managers.controller.infer_batch import (
from sglang.srt.managers.schedule_batch import (
Batch,
ForwardMode,
InputMetadata,
global_server_args_dict,
)
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
......@@ -273,7 +273,7 @@ class ModelRunner:
)
def init_cuda_graphs(self):
from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
self.cuda_graph_runner = None
......
......@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import InputMetadata
LoraConfig = None
......
......@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import InputMetadata
@torch.compile
......