Unverified Commit ed2e313e authored by Lianmin Zheng, committed by GitHub

Clean up server_args, triton cache manager (#8332)

parent f8260f25
@@ -71,7 +71,6 @@ from sglang.srt.utils import (
     is_cuda,
     kill_process_tree,
     launch_dummy_health_check_server,
-    maybe_set_triton_cache_manager,
    prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
@@ -637,11 +636,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set ulimit
     set_ulimit()

-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
     # Check flashinfer version
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
...
@@ -107,6 +107,8 @@ from sglang.version import __version__
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
+
 # Store global states
 @dataclasses.dataclass
@@ -212,9 +214,6 @@ async def validate_json_request(raw_request: Request):
     )

-HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
-
 ##### Native API endpoints #####
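Note: the health-check timeout constant only moves to module scope, next to the other import-time globals. Since it is read once at import time, an override must be exported before the server module is imported. A minimal sketch of the pattern (the override value is illustrative):

    import os

    os.environ["SGLANG_HEALTH_CHECK_TIMEOUT"] = "60"  # must be set before the module is imported

    # Same line as the diff above: defaults to 20 seconds when the env var is unset.
    HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
    print(HEALTH_CHECK_TIMEOUT)  # -> 60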
@@ -807,6 +806,24 @@ async def retrieve_model(model: str):
     )

+@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
+async def v1_score_request(request: ScoringRequest, raw_request: Request):
+    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
+    return await raw_request.app.state.openai_serving_score.handle_request(
+        request, raw_request
+    )
+
+@app.api_route(
+    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+)
+async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+    """Endpoint for reranking documents based on query relevance."""
+    return await raw_request.app.state.openai_serving_rerank.handle_request(
+        request, raw_request
+    )
+
 ## SageMaker API
 @app.get("/ping")
 async def sagemaker_health() -> Response:
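Note: the /v1/score and /v1/rerank handlers are moved above the SageMaker section unchanged (they are deleted from their old location in the next hunk). A hedged client sketch; the port is sglang's usual default, and the JSON field names are assumptions for illustration, not taken from this diff (see ScoringRequest and V1RerankReqInput for the real schemas):

    import requests

    BASE = "http://localhost:30000"  # assumed default server address

    # Hypothetical payloads; field names are illustrative only.
    score = requests.post(f"{BASE}/v1/score", json={"model": "m", "query": "q", "items": ["a", "b"]})
    rerank = requests.post(f"{BASE}/v1/rerank", json={"query": "q", "documents": ["d1", "d2"]})
    print(score.status_code, rerank.status_code)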
@@ -852,24 +869,6 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
     return ORJSONResponse({"predictions": ret})

-@app.post("/v1/score", dependencies=[Depends(validate_json_request)])
-async def v1_score_request(request: ScoringRequest, raw_request: Request):
-    """Endpoint for the decoder-only scoring API. See Engine.score() for detailed documentation."""
-    return await raw_request.app.state.openai_serving_score.handle_request(
-        request, raw_request
-    )
-
-@app.api_route(
-    "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
-)
-async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
-    """Endpoint for reranking documents based on query relevance."""
-    return await raw_request.app.state.openai_serving_rerank.handle_request(
-        request, raw_request
-    )
-
 def _create_error_response(e):
     return ORJSONResponse(
         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -916,15 +915,6 @@ def launch_server(
     add_prometheus_middleware(app)
     enable_func_timer()

-    image_token_text = None
-    if (
-        tokenizer_manager.image_token_id is not None
-        and not server_args.skip_tokenizer_init
-    ):
-        image_token_text = tokenizer_manager.tokenizer.decode(
-            [tokenizer_manager.image_token_id]
-        )
-
     # Send a warmup request - we will create the thread launch it
     # in the lifespan after all other warmups have fired.
     warmup_thread = threading.Thread(
@@ -932,7 +922,6 @@
         args=(
             server_args,
             pipe_finish_writer,
-            image_token_text,
             launch_callback,
         ),
     )
@@ -1066,7 +1055,6 @@ def _execute_server_warmup(
 def _wait_and_warmup(
     server_args: ServerArgs,
     pipe_finish_writer: Optional[multiprocessing.connection.Connection],
-    image_token_text: str,
     launch_callback: Optional[Callable[[], None]] = None,
 ):
     if not server_args.skip_server_warmup:
...
@@ -15,7 +15,7 @@
 from __future__ import annotations

 import math
-from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional

 import torch
 import torch.nn.functional as F
@@ -39,10 +39,10 @@ from sglang.srt.utils import (
 _is_cuda = is_cuda()
 _is_hip = is_hip()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -54,7 +54,6 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-
 if _is_npu:
     import torch_npu
...
@@ -653,6 +653,9 @@ class Scheduler(
             )
         )

+        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
+        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+
     def init_profier(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
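Note: the cache-size logic that used to live in run_scheduler_process (removed in a later hunk) collapses to a one-liner here: os.environ.get with a string default replaces the explicit membership test. A standalone illustration of the equivalence:

    import os

    # New form: the default "100" applies when SGLANG_VLM_CACHE_SIZE_MB is unset.
    size_mb = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
    cache_bytes = size_mb * 1024 * 1024
    print(cache_bytes)  # 104857600 unless the env var overrides it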
@@ -2895,9 +2898,9 @@ def run_scheduler_process(
         prefix += f" PP{pp_rank}"

     # Config the process
+    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
-    kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2912,10 +2915,6 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    embedding_cache_size = 100
-    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
-        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
-    init_embedding_cache(embedding_cache_size * 1024 * 1024)
-
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2926,8 +2925,8 @@
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
             if server_args.pp_size > 1:
                 scheduler.event_loop_pp()
...
@@ -74,8 +74,6 @@ class ForwardMode(IntEnum):
     MIXED = auto()
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
     IDLE = auto()
-    # Split Prefill for PD multiplexing
-    SPLIT_PREFILL = auto()

     # Used in speculative decoding: verify a batch in the target model.
     TARGET_VERIFY = auto()
@@ -86,6 +84,9 @@
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()

+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()
+
     def is_prefill(self):
         return self.is_extend()
@@ -103,12 +104,12 @@
     def is_mixed(self):
         return self == ForwardMode.MIXED

-    def is_split_prefill(self):
-        return self == ForwardMode.SPLIT_PREFILL
-
     def is_idle(self):
         return self == ForwardMode.IDLE

+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
     def is_target_verify(self):
         return self == ForwardMode.TARGET_VERIFY
@@ -132,8 +133,8 @@
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST

-    def is_decode_or_idle(self):
-        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL

 @total_ordering
...
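Note: the enum edits only regroup members and swap the positions of two predicate helpers; behavior is unchanged. A trimmed, standalone re-declaration (illustration only, not the sglang class) showing how callers branch on the helpers:

    from enum import IntEnum, auto

    class ForwardMode(IntEnum):
        # Trimmed re-declaration for illustration.
        DECODE = auto()
        IDLE = auto()
        SPLIT_PREFILL = auto()

        def is_decode_or_idle(self):
            return self == ForwardMode.DECODE or self == ForwardMode.IDLE

        def is_split_prefill(self):
            return self == ForwardMode.SPLIT_PREFILL

    mode = ForwardMode.IDLE
    if mode.is_decode_or_idle():
        print("skip prefill-only work")  # prints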
@@ -109,7 +109,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_cuda,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
...
@@ -80,7 +80,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    page_size: int = 1
+    page_size: Optional[int] = None
     hybrid_kvcache_ratio: Optional[float] = None
     swa_full_tokens_ratio: float = 0.8
     disable_hybrid_swa_memory: bool = False
@@ -266,31 +266,20 @@
     def __post_init__(self):
         # Expert parallelism
-        # We put it here first due to some internal ckpt conversation issues.
         if self.enable_ep_moe:
             self.ep_size = self.tp_size
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-        if self.enable_flashinfer_moe:
-            assert (
-                self.quantization == "modelopt_fp4"
-            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-            os.environ["TRTLLM_ENABLE_PDL"] = "1"
-            self.disable_shared_experts_fusion = True
-            logger.warning(
-                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
-            )

         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
-        if self.device is None:
-            self.device = get_device()
         if self.served_model_name is None:
             self.served_model_name = self.model_path
+        if self.device is None:
+            self.device = get_device()
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
@@ -359,7 +348,6 @@
             self.chunked_prefill_size = 16384
         else:
             self.chunked_prefill_size = 4096
-        assert self.chunked_prefill_size % self.page_size == 0

         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
@@ -410,6 +398,14 @@
             )
             self.page_size = 128

+        # Set page size
+        if self.page_size is None:
+            self.page_size = 1
+
+        # AMD-specific Triton attention KV splits default number
+        if is_hip():
+            self.triton_attention_num_kv_splits = 16
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
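Note: page_size switches to the deferred-default idiom: the dataclass field is Optional[int] = None, and __post_init__ fills in 1 only if nothing else (the user, or the branch above that forces 128) has already set it. A minimal standalone sketch of the idiom; the class and flag names are illustrative, not sglang's:

    import dataclasses
    from typing import Optional

    @dataclasses.dataclass
    class Args:
        page_size: Optional[int] = None  # None means "not set by the user"
        enable_hierarchical_cache: bool = False

        def __post_init__(self):
            if self.enable_hierarchical_cache:
                self.page_size = 128  # feature-specific default
            if self.page_size is None:
                self.page_size = 1  # plain default

    print(Args().page_size)                                # 1
    print(Args(enable_hierarchical_cache=True).page_size)  # 128
    print(Args(page_size=64).page_size)                    # 64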
@@ -431,6 +427,17 @@
             self.enable_dp_attention
         ), "Please enable dp attention when setting enable_dp_lm_head. "

+        # MoE kernel
+        if self.enable_flashinfer_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+            )
+
         # DeepEP MoE
         if self.enable_deepep_moe:
             if self.deepep_mode == "normal":
@@ -502,14 +509,6 @@
                 logger.warning(
                     "DeepSeek MTP does not require setting speculative_draft_model_path."
                 )
-            elif "Llama4" in model_arch:
-                # TODO: remove this after Llama4 supports in other backends
-                if self.attention_backend != "fa3":
-                    self.attention_backend = "fa3"
-                    logger.warning(
-                        "Llama4 requires using fa3 attention backend. "
-                        "Attention backend is automatically set to fa3."
-                    )

         # Auto choose parameters
         if self.speculative_num_steps is None:
@@ -542,12 +541,11 @@
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"

+        # Model loading
         if is_remote_url(self.model_path):
             self.load_format = "remote"
+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []

-        # AMD-specific Triton attention KV splits default number
-        if is_hip():
-            self.triton_attention_num_kv_splits = 16
-
         # PD disaggregation
         if self.disaggregation_mode == "decode":
@@ -572,6 +570,7 @@
             self.disable_cuda_graph = True
             logger.warning("Cuda graph is disabled for prefill server")

+        # Propagate env vars
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
         )
@@ -580,9 +579,6 @@
             "1" if self.disable_outlines_disk_cache else "0"
         )

-        if self.custom_weight_loader is None:
-            self.custom_weight_loader = []
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and tokenizer
@@ -1227,6 +1223,13 @@
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
+        parser.add_argument(
+            "--mm-attention-backend",
+            type=str,
+            choices=["sdpa", "fa3", "triton_attn"],
+            default=ServerArgs.mm_attention_backend,
+            help="Set multimodal attention backend.",
+        )

         # Speculative decoding
         parser.add_argument(
@@ -1276,13 +1279,6 @@
             help="The path of the draft model's small vocab table.",
             default=ServerArgs.speculative_token_map,
         )
-        parser.add_argument(
-            "--mm-attention-backend",
-            type=str,
-            choices=["sdpa", "fa3", "triton_attn"],
-            default=ServerArgs.mm_attention_backend,
-            help="Set multimodal attention backend.",
-        )

         # Expert parallelism
         parser.add_argument(
@@ -1530,11 +1526,6 @@
             action="store_true",
             help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
-        parser.add_argument(
-            "--disable-overlap-cg-plan",
-            action="store_true",
-            help="Disable the overlap optimization for cudagraph preparation in eagle verify.",
-        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -1792,11 +1783,11 @@
         return hf_config

     def check_server_args(self):
+        # Check parallel size constraints
         assert (
             self.tp_size * self.pp_size
         ) % self.nnodes == 0, "tp_size must be divisible by number of nodes"
-
+        # FIXME pp constraints
         if self.pp_size > 1:
             assert (
                 self.disable_overlap_schedule
@@ -1807,11 +1798,7 @@
         assert not (
             self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
         ), "multi-node data parallel is not supported unless dp attention!"
-        assert (
-            self.max_loras_per_batch > 0
-            # FIXME
-            and (self.lora_paths is None or self.disable_radix_cache)
-        ), "compatibility of lora and radix attention is in progress"

         assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
         assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
@@ -1820,9 +1807,32 @@
             None,
         }, "moe_dense_tp_size only support 1 and None currently"

+        # Check model architecture
+        model_arch = self.get_hf_config().architectures[0]
+        if "Llama4" in model_arch:
+            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+
+        # Check LoRA
         self.check_lora_server_args()

+        # Check speculative decoding
+        if self.speculative_algorithm is not None:
+            assert (
+                not self.enable_mixed_chunk
+            ), "enable_mixed_chunk is required for speculative decoding"
+
+        # Check chunked prefill
+        assert (
+            self.chunked_prefill_size % self.page_size == 0
+        ), "chunked_prefill_size must be divisible by page_size"
+
     def check_lora_server_args(self):
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and radix attention is in progress"
+
         # Enable LoRA if any LoRA paths are provided for backward compatibility.
         if self.lora_paths:
             if self.enable_lora is None:
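Note: the chunked-prefill divisibility assert moves out of __post_init__ (where page_size could still be None) into check_server_args, which runs after all defaults are resolved. A tiny free-function sketch of the same check; the function name is hypothetical:

    def check_chunked_prefill(chunked_prefill_size: int, page_size: int) -> None:
        # Same assertion as the diff above, extracted for illustration.
        assert (
            chunked_prefill_size % page_size == 0
        ), "chunked_prefill_size must be divisible by page_size"

    check_chunked_prefill(4096, 1)     # ok: default page size
    check_chunked_prefill(16384, 128)  # ok: hierarchical-cache page size
    # check_chunked_prefill(4096, 100) would raise AssertionError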
...
@@ -336,7 +336,6 @@ class EAGLEDraftCudaGraphRunner:
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]

-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
...
@@ -937,71 +937,6 @@ def monkey_patch_vllm_gguf_config():
     setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)

-def maybe_set_triton_cache_manager() -> None:
-    """Set environment variable to tell Triton to use a
-    custom cache manager"""
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger is None:
-        manager = "sglang.srt.utils:CustomCacheManager"
-        logger.debug("Setting Triton cache manager to: %s", manager)
-        os.environ["TRITON_CACHE_MANAGER"] = manager
-
-class CustomCacheManager(FileCacheManager):
-    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
-    def __init__(self, key, override=False, dump=False):
-        from sglang.srt.distributed.parallel_state import get_tp_group
-
-        self.key = key
-        self.lock_path = None
-
-        try:
-            module_path = "triton.runtime.cache"
-            cache_module = importlib.import_module(module_path)
-
-            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
-            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
-            default_override_dir = getattr(cache_module, "default_override_dir", None)
-        except (ModuleNotFoundError, AttributeError) as e:
-            default_cache_dir = None
-            default_dump_dir = None
-            default_override_dir = None
-
-        if dump:
-            self.cache_dir = (
-                default_dump_dir()
-                if default_dump_dir is not None
-                else os.path.join(Path.home(), ".triton", "dump")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-            self.lock_path = os.path.join(self.cache_dir, "lock")
-            os.makedirs(self.cache_dir, exist_ok=True)
-        elif override:
-            self.cache_dir = (
-                default_override_dir()
-                if default_override_dir is not None
-                else os.path.join(Path.home(), ".triton", "override")
-            )
-            self.cache_dir = os.path.join(self.cache_dir, self.key)
-        else:
-            # create cache directory if it doesn't exist
-            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
-                default_cache_dir()
-                if default_cache_dir is not None
-                else os.path.join(Path.home(), ".triton", "cache")
-            )
-            if self.cache_dir:
-                try:
-                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
-                except:
-                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
-                self.cache_dir = os.path.join(self.cache_dir, self.key)
-                self.lock_path = os.path.join(self.cache_dir, "lock")
-                os.makedirs(self.cache_dir, exist_ok=True)
-            else:
-                raise RuntimeError("Could not create or locate cache dir")
-
 def set_ulimit(target_soft_limit=65535):
     # number of open files
     resource_type = resource.RLIMIT_NOFILE
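Note: the deleted shim gave each tensor-parallel process its own Triton compile-cache directory to avoid cache races; it becomes unnecessary once the fix in triton-lang/triton#4295 is in the Triton dependency. If cache isolation were ever needed on an old Triton, the stock TRITON_CACHE_DIR variable (the same one the removed code consulted) can be set per process before Triton is imported. A hedged sketch; the rank source is an assumption about the launcher, not sglang API:

    import os

    rank = int(os.environ.get("LOCAL_RANK", "0"))  # assumed launcher-provided rank
    os.environ["TRITON_CACHE_DIR"] = os.path.expanduser(f"~/.triton/cache_rank{rank}")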
...
@@ -101,7 +101,7 @@ class TestDeepseekMTP(CustomTestCase):
                 "--max-running-requests",
                 "512",
                 "--speculative-algorithm",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-num-steps",
                 "1",
                 "--speculative-eagle-topk",
...
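Note: every test change below is the same one-token rename: the speculative-algorithm flag value NEXTN becomes EAGLE, while the NextN draft models keep their names. In the tests' own style (CLI arguments as a Python list; flags and model name copied from the hunks, the step count is one of the values used below):

    other_args = [
        "--speculative-algo",
        "EAGLE",  # was "NEXTN" before this commit
        "--speculative-draft",
        "lmsys/DeepSeek-V3-0324-NextN",
        "--speculative-num-steps",
        "2",
    ]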
@@ -261,7 +261,7 @@ class TestMTP(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-deepep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
                 "--speculative-num-steps",
@@ -329,7 +329,7 @@ class TestMTPWithTBO(CustomTestCase):
                 "--enable-deepep-moe",
                 "--trust-remote-code",
                 "--speculative-algorithm",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-num-steps",
                 "2",
                 "--speculative-eagle-topk",
...
@@ -1224,7 +1224,7 @@ class Test30(CustomTestCase):
                 "--tp",
                 "8",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1271,7 +1271,7 @@ class Test31(CustomTestCase):
                 "--dp",
                 "4",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1318,7 +1318,7 @@ class Test32(CustomTestCase):
                 "--dp",
                 "8",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1364,7 +1364,7 @@ class Test33(CustomTestCase):
                 "--moe-dense-tp-size",
                 "1",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1413,7 +1413,7 @@ class Test34(CustomTestCase):
                 "--moe-dense-tp-size",
                 "1",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1462,7 +1462,7 @@ class Test35(CustomTestCase):
                 "--moe-dense-tp-size",
                 "1",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1510,7 +1510,7 @@ class Test36(CustomTestCase):
                 "4",
                 "--enable-dp-lm-head",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1558,7 +1558,7 @@ class Test37(CustomTestCase):
                 "8",
                 "--enable-dp-lm-head",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1608,7 +1608,7 @@ class Test38(CustomTestCase):
                 "1",
                 "--enable-dp-lm-head",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1658,7 +1658,7 @@ class Test39(CustomTestCase):
                 "1",
                 "--enable-dp-lm-head",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1709,7 +1709,7 @@ class Test40(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1763,7 +1763,7 @@ class Test41(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1817,7 +1817,7 @@ class Test42(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1870,7 +1870,7 @@ class Test43(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1926,7 +1926,7 @@ class Test44(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -1982,7 +1982,7 @@ class Test45(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2037,7 +2037,7 @@ class Test46(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2092,7 +2092,7 @@ class Test47(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2149,7 +2149,7 @@ class Test48(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2206,7 +2206,7 @@ class Test49(CustomTestCase):
                 "--max-running-requests",
                 "32",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2251,7 +2251,7 @@ class Test50(CustomTestCase):
                 "8",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2299,7 +2299,7 @@ class Test51(CustomTestCase):
                 "4",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2347,7 +2347,7 @@ class Test52(CustomTestCase):
                 "8",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2394,7 +2394,7 @@ class Test53(CustomTestCase):
                 "1",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2444,7 +2444,7 @@ class Test54(CustomTestCase):
                 "1",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2494,7 +2494,7 @@ class Test55(CustomTestCase):
                 "1",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2543,7 +2543,7 @@ class Test56(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2592,7 +2592,7 @@ class Test57(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2643,7 +2643,7 @@ class Test58(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
@@ -2694,7 +2694,7 @@ class Test59(CustomTestCase):
                 "--enable-dp-lm-head",
                 "--enable-ep-moe",
                 "--speculative-algo",
-                "NEXTN",
+                "EAGLE",
                 "--speculative-draft",
                 "lmsys/DeepSeek-V3-0324-NextN",
                 "--speculative-num-steps",
...