Unverified commit cdcbde5f, authored by Liangsheng Yin, committed by GitHub

Code structure refactor (#807)

parent 21e22b9e
@@ -29,7 +29,7 @@ If OOM happens during prefill, try to decrease `--max-prefill-tokens`.
 If OOM happens during decoding, try to decrease `--max-running-requests`.
 You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
 
-### (Minor) Tune `--schedule-heuristic`
-If you have many shared prefixes, use the default `--schedule-heuristic lpm`. `lpm` stands for longest prefix match.
+### (Minor) Tune `--schedule-policy`
+If you have many shared prefixes, use the default `--schedule-policy lpm`. `lpm` stands for longest prefix match.
 When you have no shared prefixes at all or you always send the requests with the shared prefixes together,
-you can try `--schedule-heuristic fcfs`. `fcfs` stands for first come first serve.
+you can try `--schedule-policy fcfs`. `fcfs` stands for first come first serve.
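These flags can also be set programmatically. A minimal sketch, assuming `sgl.Runtime` forwards keyword arguments to the server's arguments (the model path and values are illustrative; note the flag is `schedule_policy` after this rename):

```python
import sglang as sgl

# Illustrative values; the keyword names mirror the CLI flags above,
# with `schedule_heuristic` renamed to `schedule_policy` by this commit.
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # hypothetical model choice
    schedule_policy="lpm",      # good default when requests share prefixes
    mem_fraction_static=0.8,    # shrink the KV cache pool if OOM persists
)
sgl.set_default_backend(runtime)
```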
 # SGL API Components
 from sglang.api import (
     Runtime,
     assistant,
@@ -22,46 +23,46 @@ from sglang.api import (
     video,
 )
 
-# Global Configurations
-from sglang.global_config import global_config
-
-# SGL Backends
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.utils import LazyImport
-from sglang.version import __version__
-
-Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
-LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
-OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
-VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
-
+# SGLang DSL APIs
+# public APIs management
 __all__ = [
-    "global_config",
-    "Anthropic",
-    "LiteLLM",
-    "OpenAI",
-    "RuntimeEndpoint",
-    "VertexAI",
-    "function",
     "Runtime",
-    "set_default_backend",
+    "assistant",
+    "assistant_begin",
+    "assistant_end",
     "flush_cache",
-    "get_server_args",
+    "function",
     "gen",
     "gen_int",
     "gen_string",
+    "get_server_args",
     "image",
-    "video",
     "select",
+    "set_default_backend",
     "system",
+    "system_begin",
+    "system_end",
     "user",
-    "assistant",
     "user_begin",
     "user_end",
-    "assistant_begin",
-    "assistant_end",
-    "system_begin",
-    "system_end",
+    "video",
 ]
+
+# Global Configurations
+from sglang.global_config import global_config
+
+__all__ += ["global_config"]
+
+from sglang.version import __version__
+
+__all__ += ["__version__"]
+
+# SGL Backends
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import LazyImport
+
+Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
+LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
+OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
+VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+
+__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
@@ -37,9 +37,9 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import suppress_other_loggers
...
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
...
@@ -22,7 +22,7 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import (
+from sglang.srt.model_executor.model_runner import (
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
...
@@ -20,7 +20,7 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
...
@@ -27,7 +27,7 @@ from enum import Enum, auto
 import numpy as np
 import zmq
 
-from sglang.srt.managers.controller.manager_single import (
+from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
...
@@ -22,7 +22,7 @@ from typing import List
 import zmq
 
-from sglang.srt.managers.controller.tp_worker import (
+from sglang.srt.managers.tp_worker import (
     ModelTpServer,
     broadcast_recv_input,
     launch_tp_servers,
...
@@ -25,8 +25,8 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
...
@@ -22,7 +22,7 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
...
@@ -13,47 +13,47 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Request scheduler heuristic."""
+"""Request policy scheduler"""
 
 import random
 from collections import defaultdict
 
 
-class ScheduleHeuristic:
+class PolicyScheduler:
     def __init__(
         self,
-        schedule_heuristic,
+        policy,
         max_running_seqs,
         max_prefill_num_tokens,
         max_total_num_tokens,
         tree_cache,
     ):
-        if tree_cache.disable and schedule_heuristic == "lpm":
+        if tree_cache.disable and policy == "lpm":
             # LPM is meaningless when the tree cache is disabled.
-            schedule_heuristic = "fcfs"
+            policy = "fcfs"
 
-        self.schedule_heuristic = schedule_heuristic
+        self.policy = policy
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
         self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
     def get_priority_queue(self, waiting_queue):
-        if self.schedule_heuristic == "lpm":
+        if self.policy == "lpm":
             # longest prefix match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
             return waiting_queue
-        elif self.schedule_heuristic == "fcfs":
+        elif self.policy == "fcfs":
             # first come first serve
             return waiting_queue
-        elif self.schedule_heuristic == "lof":
+        elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
             return waiting_queue
-        elif self.schedule_heuristic == "random":
+        elif self.policy == "random":
             random.shuffle(waiting_queue)
             return waiting_queue
-        elif self.schedule_heuristic == "dfs-weight":
+        elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
                 last_node_to_reqs[req.last_node].append(req)
@@ -70,7 +70,7 @@ class ScheduleHeuristic:
             assert len(q) == len(waiting_queue)
             return q
         else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
+            raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
     def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
...
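The `dfs-weight` branch is cut off where the diff is truncated, so the weighting and traversal logic is not shown. A self-contained sketch of the general idea, assuming each radix-tree node exposes a `children` dict: weight every node by the number of queued requests in its subtree, then emit requests depth-first, heaviest subtree first, so requests sharing a cached prefix are scheduled together.

```python
from collections import defaultdict


def dfs_weight_order(root, last_node_to_reqs):
    """Order requests so those sharing a prefix subtree come out together."""
    node_to_weight = defaultdict(int)

    def calc_weight(node):
        # A node's weight counts the queued requests at and below it.
        node_to_weight[node] += len(last_node_to_reqs[node])
        for child in node.children.values():
            calc_weight(child)
            node_to_weight[node] += node_to_weight[child]

    def visit(node, out):
        out.extend(last_node_to_reqs[node])
        # Heavier subtrees first, so large groups of prefix-sharing
        # requests are dequeued as one contiguous run.
        for child in sorted(node.children.values(), key=lambda c: -node_to_weight[c]):
            visit(child, out)

    calc_weight(root)
    ordered = []
    visit(root, ordered)
    return ordered
```

Under this ordering, a subtree holding many queued requests is drained before a lighter sibling, which maximizes radix-cache reuse within a batch.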
@@ -28,8 +28,8 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.radix_cache import RadixCache
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...
@@ -29,23 +29,23 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.controller.infer_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Batch,
-    ForwardMode,
-    Req,
-)
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
+from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.schedule_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_int_token_logit_bias,
@@ -74,7 +74,7 @@ class ModelTpServer:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
         # Chunked prefill
@@ -150,8 +150,8 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = ScheduleHeuristic(
-            self.schedule_heuristic,
+        self.scheduler = PolicyScheduler(
+            self.schedule_policy,
             self.max_running_requests,
             self.max_prefill_tokens,
             self.max_total_num_tokens,
...
@@ -17,7 +17,7 @@ limitations under the License.
 Flush the KV cache.
 
 Usage:
-python3 -m sglang.srt.flush_cache --url http://localhost:30000
+python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
 """
 
 import argparse
...
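For context, the relocated module is a thin HTTP client. A minimal sketch of what such a script does, assuming the running server exposes a `/flush_cache` endpoint at the given URL (an assumption; check your server version):

```python
import argparse

import requests

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="http://localhost:30000")
    args = parser.parse_args()

    # Ask the running server to drop its radix/KV cache contents.
    response = requests.get(args.url + "/flush_cache")
    assert response.status_code == 200, f"Flush failed: HTTP {response.status_code}"
```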
@@ -29,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
...
@@ -40,13 +40,13 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
+from sglang.srt.managers.schedule_batch import (
     Batch,
     ForwardMode,
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -273,7 +273,7 @@ class ModelRunner:
         )
 
     def init_cuda_graphs(self):
-        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
...
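`init_cuda_graphs` is skipped when CUDA graphs or flashinfer are disabled. For readers unfamiliar with the technique, here is a generic PyTorch capture-and-replay sketch, not SGLang's actual `CudaGraphRunner`: the forward pass is recorded once into fixed buffers, then replayed with new data copied in, which removes per-step kernel-launch overhead during decoding.

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
static_input = torch.zeros(8, 1024, device="cuda")

# Warm up on a side stream before capture, as the CUDA graph docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph; tensors allocated during capture
# become fixed buffers owned by the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the captured input buffer, then replay.
static_input.copy_(torch.randn(8, 1024, device="cuda"))
graph.replay()
print(static_output.sum().item())
```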
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 LoraConfig = None
...
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import InputMetadata
 
 
 @torch.compile
...