Unverified commit 935cda94, authored by Lianmin Zheng, committed by GitHub

Misc clean up; Remove the support of jump forward (#4032)

parent 110e0066
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton

-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )

-from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
...
@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 <
 class DecodeStatus:
     """Store the status of incremental decoding."""

-    vid: int
     decoded_text: str
     decode_ids: List[int]
     surr_offset: int
@@ -143,10 +142,8 @@ class DetokenizerManager:
         read_ids, surr_ids = [], []
         for i in range(bs):
             rid = recv_obj.rids[i]
-            vid = recv_obj.vids[i]
-            if rid not in self.decode_status or self.decode_status[rid].vid != vid:
+            if rid not in self.decode_status:
                 s = DecodeStatus(
-                    vid=vid,
                     decoded_text=recv_obj.decoded_texts[i],
                     decode_ids=recv_obj.decode_ids[i],
                     surr_offset=0,
...
@@ -376,8 +376,6 @@ class BatchTokenIDOut:
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
-    # The version id to sync decode status with in detokenizer_manager
-    vids: List[int]
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
...
@@ -296,7 +296,6 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
-        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
         self.decoded_text = ""
@@ -357,11 +356,6 @@ class Req:
         ) = None
         self.hidden_states = []

-        # Logprobs (internal values)
-        # The tokens is prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0

         # Embedding (return values)
         self.embedding = None
...
@@ -500,68 +494,6 @@ class Req:
                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                 return

-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
-        all_ids = self.tokenizer.encode(all_text)
-        if not all_ids:
-            logger.warning("Encoded all_text resulted in empty all_ids")
-            return False
-
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        if prompt_tokens > len(all_ids):
-            logger.warning("prompt_tokens is larger than encoded all_ids")
-            return False
-        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
-            # TODO(lsyin): fix token fusion
-            logger.warning(
-                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
-            )
-            return False
-
-        old_output_ids = self.output_ids
-        self.output_ids = all_ids[prompt_tokens:]
-        self.decoded_text = self.decoded_text + jump_forward_str
-        self.surr_offset = prompt_tokens
-        self.read_offset = len(all_ids)
-
-        # NOTE: A trick to reduce the surrouding tokens decoding overhead
-        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
-            surr_text_ = self.tokenizer.decode(
-                all_ids[self.read_offset - i : self.read_offset]
-            )
-            if not surr_text_.endswith("�"):
-                self.surr_offset = self.read_offset - i
-                break
-
-        # update the inner state of the grammar
-        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
-        if self.return_logprob:
-            # For fast-forward part's logprobs
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-
-            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
-            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
-            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
-            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
-            self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
-            self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.output_ids) - k
-
-        return True
-
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
@@ -574,8 +506,6 @@ class Req:
         self.is_chunked = 0
         self.req_pool_idx = None

-        self.last_update_decode_tokens = 0
-
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
@@ -672,7 +602,6 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -687,7 +616,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
+            return_hidden_states=any(req.return_hidden_states for req in reqs),
         )

     def batch_size(self):
@@ -1091,59 +1020,6 @@ class ScheduleBatch:
         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, pad_input_ids_func):
-        jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
-
-        for i, req in enumerate(self.reqs):
-            if req.grammar is not None:
-                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
-                if jump_helper:
-                    suffix_ids, _ = jump_helper
-
-                    # Current ids, for cache and revert
-                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
-                    cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.get_next_inc_detokenization()
-                    if not decode_res:
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.grammar.jump_forward_str_state(jump_helper)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not corrupt
-                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    # The decode status has diverged from detokenizer_manager
-                    req.vid += 1
-
-                    # insert the old request into tree_cache
-                    self.tree_cache.cache_finished_req(req, cur_all_ids)
-
-                    # re-applying image padding
-                    if req.image_inputs is not None:
-                        req.origin_input_ids = pad_input_ids_func(
-                            req.origin_input_ids_unpadded, req.image_inputs
-                        )
-
-                    jump_forward_reqs.append(req)
-                    keep_indices.remove(i)
-
-        self.filter_batch(keep_indices=list(keep_indices))
-
-        return jump_forward_reqs
-
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
...
@@ -150,7 +150,6 @@ class Scheduler:
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_policy = server_args.schedule_policy
-        self.disable_jump_forward = server_args.disable_jump_forward
        self.lora_paths = server_args.lora_paths
        self.max_loras_per_batch = server_args.max_loras_per_batch
        self.enable_overlap = not server_args.disable_overlap_schedule
@@ -251,9 +250,6 @@ class Scheduler:
            self.enable_overlap = False
            logger.info("Overlap scheduler is disabled for multimodal models.")

-        if self.enable_overlap:
-            self.disable_jump_forward = True
-
         # Launch a tensor parallel worker
         if self.enable_overlap:
             TpWorkerClass = TpModelWorkerClient
@@ -1024,11 +1020,8 @@ class Scheduler:
             if self.running_batch is not None
             else set([])
         )

-        return_hidden_states = False
-
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
-            if req.return_hidden_states:
-                return_hidden_states = True
-
             if (
                 self.lora_paths
                 and len(
@@ -1114,7 +1107,6 @@ class Scheduler:
             self.enable_overlap,
             self.spec_algorithm,
             self.server_args.enable_custom_logit_processor,
-            return_hidden_states,
         )

         new_batch.prepare_for_extend()
@@ -1168,14 +1160,6 @@ class Scheduler:
                 self.min_new_token_ratio,
             )

-        # Check for jump-forward
-        if not self.disable_jump_forward and batch.has_grammar:
-            jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
-            self._extend_requests_to_queue(jump_forward_reqs)
-            if batch.is_empty():
-                self.batch_is_full = False
-                return None
-
         if batch.batch_size() < initial_bs:
             self.batch_is_full = False
@@ -1530,8 +1514,6 @@ class Scheduler:
         prefill (e.g., computing input token logprobs).
         """
         assert output.input_token_logprobs is not None
-        # It is for jump decoding that will be deprecated.
-        assert req.last_update_decode_tokens == 0
         if req.input_token_logprobs is None:
             req.input_token_logprobs = []
         if req.temp_input_top_logprobs_val is None:
...
@@ -1658,50 +1640,12 @@ class Scheduler:
                 self.add_input_logprob_return_values(
                     i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
                 )

-                if req.last_update_decode_tokens != 0:
-                    # Some decode tokens are re-computed in an extend batch
-                    req.output_token_logprobs_val.extend(
-                        output.input_token_logprobs[
-                            pt
-                            + num_input_logprobs
-                            - 1
-                            - req.last_update_decode_tokens : pt
-                            + num_input_logprobs
-                            - 1
-                        ],
-                    )
-                    req.output_token_logprobs_idx.extend(
-                        req.fill_ids[
-                            len(req.fill_ids)
-                            - req.last_update_decode_tokens : len(req.fill_ids)
-                        ]
-                    )
-
                 if req.top_logprobs_num > 0:
-                    if req.last_update_decode_tokens != 0:
-                        req.output_top_logprobs_val.extend(
-                            output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
-                        )
-                        req.output_top_logprobs_idx.extend(
-                            output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
-                        )
                     req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
                     req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

                 if req.token_ids_logprob is not None:
-                    if req.last_update_decode_tokens != 0:
-                        req.output_token_ids_logprobs_val.extend(
-                            output.input_token_ids_logprobs_val[i][
-                                -req.last_update_decode_tokens :
-                            ]
-                        )
-                        req.output_token_ids_logprobs_idx.extend(
-                            output.input_token_ids_logprobs_idx[i][
-                                -req.last_update_decode_tokens :
-                            ]
-                        )
                     req.output_token_ids_logprobs_val.append(
                         output.next_token_token_ids_logprobs_val[i]
                     )
...
@@ -1719,7 +1663,6 @@ class Scheduler:
         finished_reasons: List[BaseFinishReason] = []

         if self.is_generation:
-            vids = []
             decoded_texts = []
             decode_ids_list = []
             read_offsets = []
@@ -1786,7 +1729,6 @@ class Scheduler:
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
                 )
-                vids.append(req.vid)
                 decoded_texts.append(req.decoded_text)
                 decode_ids, read_offset = req.init_incremental_detokenize()
                 decode_ids_list.append(decode_ids)
@@ -1842,7 +1784,6 @@ class Scheduler:
                 BatchTokenIDOut(
                     rids,
                     finished_reasons,
-                    vids,
                     decoded_texts,
                     decode_ids_list,
                     read_offsets,
...
@@ -41,7 +41,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
-    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
...
@@ -26,8 +26,6 @@ from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from pydantic import ValidationError

-from sglang.lang.chat_template import get_chat_template_by_model_path
-
 try:
     from outlines.fsm.json_schema import convert_json_schema_to_str
 except ImportError:
@@ -165,24 +163,19 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
     else:
         chat_template_name = chat_template_arg

-    # check chat-template
-    chat_template = get_chat_template_by_model_path(model_path)
-    if chat_template is not None:
-        official_chat_template = chat_template.name
-        used_chat_template = chat_template_name
-        if official_chat_template != used_chat_template:
-            logger.warning(
-                f"Using a chat_template: '{used_chat_template}', "
-                f"which is different from official chat template: '{official_chat_template}', "
-                f"This discrepancy may lead to performance degradation."
-            )
+    # Check chat-template
+    # TODO:
+    # 1. Do not import any code from sglang.lang
+    # 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.


-async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+async def v1_files_create(
+    file: UploadFile, purpose: str, file_storage_path: str = None
+):
     try:
         global storage_dir
-        if file_storage_pth:
-            storage_dir = file_storage_pth
+        if file_storage_path:
+            storage_dir = file_storage_path

         # Read the file content
         file_content = await file.read()
...
@@ -40,17 +40,23 @@ class SamplingParams:
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
-        spaces_between_special_tokens: bool = True,
         n: int = 1,
         json_schema: Optional[str] = None,
         regex: Optional[str] = None,
         ebnf: Optional[str] = None,
         structural_tag: Optional[str] = None,
-        no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
+        spaces_between_special_tokens: bool = True,
+        no_stop_trim: bool = False,
         custom_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        self.max_new_tokens = max_new_tokens
+        self.stop_strs = stop
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -58,26 +64,21 @@ class SamplingParams:
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
-        self.stop_strs = stop
-        if stop_token_ids:
-            self.stop_token_ids = set(stop_token_ids)
-        else:
-            self.stop_token_ids = None
-        self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
-        self.ignore_eos = ignore_eos
-        self.skip_special_tokens = skip_special_tokens
-        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
         self.ebnf = ebnf
         self.structural_tag = structural_tag
+        self.ignore_eos = ignore_eos
+        self.skip_special_tokens = skip_special_tokens
+        self.spaces_between_special_tokens = spaces_between_special_tokens
         self.no_stop_trim = no_stop_trim
         self.custom_params = custom_params

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
+            # top_k = 1 means greedy sampling
             self.temperature = 1.0
             self.top_k = 1
         if self.top_k == -1:
...
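As context for the special case kept at the end of the hunk above: when `temperature` is effectively zero, the constructor collapses sampling to greedy decoding by forcing `top_k = 1` and resetting `temperature` to 1.0. A minimal, self-contained sketch of that normalization follows; the epsilon value and the helper name are illustrative stand-ins, not taken from the sglang source.

```python
# Hedged sketch of the greedy-sampling normalization shown above.
# _SAMPLING_EPS and normalize_for_greedy are illustrative stand-ins.
_SAMPLING_EPS = 1e-6


def normalize_for_greedy(temperature: float, top_k: int) -> tuple[float, int]:
    """Near-zero temperature is treated as greedy decoding."""
    if temperature < _SAMPLING_EPS:
        # top_k = 1 means greedy sampling; temperature is reset to a safe
        # value so later division by temperature stays well-defined.
        return 1.0, 1
    return temperature, top_k


print(normalize_for_greedy(0.0, 50))  # -> (1.0, 1): greedy
print(normalize_for_greedy(0.7, 50))  # -> (0.7, 50): unchanged
```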
@@ -15,21 +15,15 @@
 import argparse
 import dataclasses
-import json
 import logging
-import os
 import random
-import subprocess
 import tempfile
-import uuid
-from pathlib import Path
 from typing import List, Optional

 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
-    create_checksum,
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -101,7 +95,7 @@ class ServerArgs:
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "sglang_storage"
+    file_storage_path: str = "sglang_storage"
     enable_cache_report: bool = False

     # Data parallelism
@@ -149,7 +143,6 @@ class ServerArgs:
     # Optimization/debug options
     disable_radix_cache: bool = False
-    disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     enable_nccl_nvls: bool = False
@@ -627,9 +620,9 @@ class ServerArgs:
             help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
-            "--file-storage-pth",
+            "--file-storage-path",
             type=str,
-            default=ServerArgs.file_storage_pth,
+            default=ServerArgs.file_storage_path,
             help="The path of the file storage in backend.",
         )
         parser.add_argument(
@@ -836,11 +829,6 @@ class ServerArgs:
             action="store_true",
             help="Disable RadixAttention for prefix caching.",
         )
-        parser.add_argument(
-            "--disable-jump-forward",
-            action="store_true",
-            help="Disable jump-forward for grammar-guided decoding.",
-        )
         parser.add_argument(
             "--disable-cuda-graph",
             action="store_true",
...
@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
-DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"

 def is_in_ci():
...
""" """
Usage: Usage:
# single GPU # single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
""" """
import argparse import argparse
......
@@ -17,3 +17,59 @@ For CUDA 12.1 or CUDA 12.4:
```bash
pip3 install sgl-kernel
```
# Developer Guide
## Development Environment Setup
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
Create and enter development container:
```bash
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
```
## Project Structure
### Dependencies
Third-party libraries:
- [CCCL](https://github.com/NVIDIA/cccl)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
- [TurboMind](https://github.com/InternLM/turbomind)
### Kernel Development
Steps to add a new kernel (a hedged Python-side sketch follows this list):
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
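
To make steps 4 and 5 concrete, here is a hedged sketch of what a thin Python wrapper over a registered op might look like. `my_kernel` is a hypothetical placeholder, and the `torch.ops.sgl_kernels` namespace is assumed from the `TORCH_LIBRARY_EXPAND(sgl_kernels, ...)` registration; check the existing wrappers in `src/sgl-kernel/ops/__init__.py` for the actual pattern.

```python
# Hedged sketch of a Python wrapper (steps 4-5); "my_kernel" is a placeholder,
# not an existing sgl-kernel op.
import torch


def my_kernel(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """Validate inputs in Python, then dispatch to the registered CUDA op."""
    assert x.is_cuda, "sgl-kernel ops expect CUDA tensors"
    assert x.is_contiguous(), "kernels typically assume contiguous layouts"
    # The symbol below would be created by the torch extension from step 3;
    # it is shown only to illustrate the dispatch, not as a real op name.
    return torch.ops.sgl_kernels.my_kernel(x, scale)
```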
### Build & Install
Development build:
```bash
make build
```
Note:
`sgl-kernel` is evolving rapidly. If you experience a compilation failure, try `make rebuild`.
### Testing & Benchmarking
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) (see the sketch after this list)
3. Run the test suite
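
For step 2, below is a minimal sketch of the triton benchmarking pattern; the operator being timed (`torch.softmax`) and the size grid are placeholders rather than an actual sgl-kernel benchmark.

```python
# Hedged triton.testing benchmark sketch; the op and sizes are placeholders.
import torch
import triton
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["n"],                 # sweep the last dimension
        x_vals=[1024, 2048, 4096],
        line_arg="provider",
        line_vals=["torch"],
        line_names=["torch.softmax"],
        plot_name="softmax-latency",
        args={},
    )
)
def benchmark(n, provider):
    x = torch.randn(64, n, device="cuda")
    # do_bench runs warmup plus timed repetitions and returns latency in ms.
    return triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))


if __name__ == "__main__":
    benchmark.run(print_data=True)
```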
### Release a New Version
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
@@ -100,6 +100,7 @@ sources = [
     "src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
     "src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
     "src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
+    "src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
     "src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
@@ -108,7 +109,6 @@ sources = [
     "src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/speculative/eagle_utils.cu",
     "src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
-    "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/norm.cu",
...
@@ -62,6 +62,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
   m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
   m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);

+  /*
+   * From csrc/attention
+   */
+  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+
   /*
    * From csrc/gemm
    */
@@ -163,11 +168,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
       "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
-
-  /*
-   * Other
-   */
-  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
 }

 REGISTER_EXTENSION(_kernels)
@@ -46,7 +46,6 @@ class TestEBNFConstrained(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         setup_class(cls, "xgrammar", disable_overlap=False)
-        cls.check_jump_forward = False

     @classmethod
     def tearDownClass(cls):
@@ -238,12 +237,5 @@ class TestEBNFConstrained(unittest.TestCase):
         )

-class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, "llguidance", disable_overlap=False)
-        cls.check_jump_forward = False
-

 if __name__ == "__main__":
     unittest.main()
@@ -57,7 +57,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         setup_class(cls, backend="outlines", disable_overlap=False)
-        cls.check_jump_forward = False

     @classmethod
     def tearDownClass(cls):
@@ -134,26 +133,5 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
             list(executor.map(self.run_decode, json_schemas))

-class TestJumpForwardOutlinesBackend(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, backend="outlines", disable_overlap=True)
-        cls.check_jump_forward = True
-
-
-class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, backend="xgrammar", disable_overlap=False)
-        cls.check_jump_forward = False
-
-
-class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, backend="llguidance", disable_overlap=False)
-        cls.check_jump_forward = False
-

 if __name__ == "__main__":
     unittest.main()
@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
     DEFAULT_MOE_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    is_in_ci,
     popen_launch_server,
+    write_github_step_summary,
 )
@@ -49,6 +51,9 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         metrics = run_eval(args)
         self.assertGreater(metrics["score"], 0.62)

+        if is_in_ci():
+            write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
+
     def test_human_eval(self):
         args = SimpleNamespace(
             base_url=self.base_url,
@@ -61,6 +66,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         metrics = run_eval(args)
         self.assertGreater(metrics["score"], 0.40)

+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
+            )
+
     def test_mgsm_en(self):
         args = SimpleNamespace(
             base_url=self.base_url,
@@ -73,6 +83,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
         metrics = run_eval(args)
         self.assertGreater(metrics["score"], 0.61)

+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
+            )
+

 if __name__ == "__main__":
     unittest.main()