Unverified Commit 935cda94 authored by Lianmin Zheng, committed by GitHub

Misc clean up; Remove the support of jump forward (#4032)

parent 110e0066
......@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional, Union
import torch
import triton
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.flashinfer_backend import (
create_flashinfer_kv_indices_triton,
)
......
......@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
......
......@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 <
class DecodeStatus:
"""Store the status of incremental decoding."""
vid: int
decoded_text: str
decode_ids: List[int]
surr_offset: int
......@@ -143,10 +142,8 @@ class DetokenizerManager:
read_ids, surr_ids = [], []
for i in range(bs):
rid = recv_obj.rids[i]
vid = recv_obj.vids[i]
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
if rid not in self.decode_status:
s = DecodeStatus(
vid=vid,
decoded_text=recv_obj.decoded_texts[i],
decode_ids=recv_obj.decode_ids[i],
surr_offset=0,
......
......@@ -376,8 +376,6 @@ class BatchTokenIDOut:
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]
......
......@@ -296,7 +296,6 @@ class Req:
# 1: surr_offset
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
self.decoded_text = ""
......@@ -357,11 +356,6 @@ class Req:
) = None
self.hidden_states = []
# Logprobs (internal values)
# These tokens are prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
# Embedding (return values)
self.embedding = None
......@@ -500,68 +494,6 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
all_ids = self.tokenizer.encode(all_text)
if not all_ids:
logger.warning("Encoded all_text resulted in empty all_ids")
return False
prompt_tokens = len(self.origin_input_ids_unpadded)
if prompt_tokens > len(all_ids):
logger.warning("prompt_tokens is larger than encoded all_ids")
return False
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
old_output_ids = self.output_ids
self.output_ids = all_ids[prompt_tokens:]
self.decoded_text = self.decoded_text + jump_forward_str
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)
# NOTE: A trick to reduce the overhead of decoding surrounding tokens
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
)
if not surr_text_.endswith("�"):
self.surr_offset = self.read_offset - i
break
# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == self.output_ids[i]:
k = k + 1
else:
break
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
return True
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
......@@ -574,8 +506,6 @@ class Req:
self.is_chunked = 0
self.req_pool_idx = None
self.last_update_decode_tokens = 0
def __repr__(self):
return (
f"Req(rid={self.rid}, "
......@@ -672,7 +602,6 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
......@@ -687,7 +616,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
return_hidden_states=any(req.return_hidden_states for req in reqs),
)
def batch_size(self):
......@@ -1091,59 +1020,6 @@ class ScheduleBatch:
return retracted_reqs, new_estimate_ratio
def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
keep_indices = set(i for i in range(len(self.reqs)))
for i, req in enumerate(self.reqs):
if req.grammar is not None:
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
if jump_helper:
suffix_ids, _ = jump_helper
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
req.output_ids.extend(suffix_ids)
decode_res, new_text = req.get_next_inc_detokenization()
if not decode_res:
req.output_ids = cur_output_ids
continue
(
jump_forward_str,
next_state,
) = req.grammar.jump_forward_str_state(jump_helper)
# Make the incrementally decoded text part of jump_forward_str
# so that UTF-8 decoding will not be corrupted
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
):
req.output_ids = cur_output_ids
continue
# The decode status has diverged from detokenizer_manager
req.vid += 1
# insert the old request into tree_cache
self.tree_cache.cache_finished_req(req, cur_all_ids)
# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
jump_forward_reqs.append(req)
keep_indices.remove(i)
self.filter_batch(keep_indices=list(keep_indices))
return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)
......
......@@ -150,7 +150,6 @@ class Scheduler:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_policy = server_args.schedule_policy
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
......@@ -251,9 +250,6 @@ class Scheduler:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")
if self.enable_overlap:
self.disable_jump_forward = True
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
......@@ -1024,11 +1020,8 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if req.return_hidden_states:
return_hidden_states = True
if (
self.lora_paths
and len(
......@@ -1114,7 +1107,6 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
return_hidden_states,
)
new_batch.prepare_for_extend()
......@@ -1168,14 +1160,6 @@ class Scheduler:
self.min_new_token_ratio,
)
# Check for jump-forward
if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self._extend_requests_to_queue(jump_forward_reqs)
if batch.is_empty():
self.batch_is_full = False
return None
if batch.batch_size() < initial_bs:
self.batch_is_full = False
......@@ -1530,8 +1514,6 @@ class Scheduler:
prefill (e.g., computing input token logprobs).
"""
assert output.input_token_logprobs is not None
# It is for jump decoding that will be deprecated.
assert req.last_update_decode_tokens == 0
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
......@@ -1658,50 +1640,12 @@ class Scheduler:
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)
if req.top_logprobs_num > 0:
if req.last_update_decode_tokens != 0:
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
if req.last_update_decode_tokens != 0:
req.output_token_ids_logprobs_val.extend(
output.input_token_ids_logprobs_val[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_idx.extend(
output.input_token_ids_logprobs_idx[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
......@@ -1719,7 +1663,6 @@ class Scheduler:
finished_reasons: List[BaseFinishReason] = []
if self.is_generation:
vids = []
decoded_texts = []
decode_ids_list = []
read_offsets = []
......@@ -1786,7 +1729,6 @@ class Scheduler:
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
......@@ -1842,7 +1784,6 @@ class Scheduler:
BatchTokenIDOut(
rids,
finished_reasons,
vids,
decoded_texts,
decode_ids_list,
read_offsets,
......
......@@ -41,7 +41,7 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
......
......@@ -26,8 +26,6 @@ from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
from pydantic import ValidationError
from sglang.lang.chat_template import get_chat_template_by_model_path
try:
from outlines.fsm.json_schema import convert_json_schema_to_str
except ImportError:
......@@ -165,24 +163,19 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg
# check chat-template
chat_template = get_chat_template_by_model_path(model_path)
if chat_template is not None:
official_chat_template = chat_template.name
used_chat_template = chat_template_name
if official_chat_template != used_chat_template:
logger.warning(
f"Using a chat_template: '{used_chat_template}', "
f"which is different from official chat template: '{official_chat_template}', "
f"This discrepancy may lead to performance degradation."
)
# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
async def v1_files_create(
file: UploadFile, purpose: str, file_storage_path: str = None
):
try:
global storage_dir
if file_storage_pth:
storage_dir = file_storage_pth
if file_storage_path:
storage_dir = file_storage_path
# Read the file content
file_content = await file.read()
......
......@@ -40,17 +40,23 @@ class SamplingParams:
presence_penalty: float = 0.0,
repetition_penalty: float = 1.0,
min_new_tokens: int = 0,
spaces_between_special_tokens: bool = True,
n: int = 1,
json_schema: Optional[str] = None,
regex: Optional[str] = None,
ebnf: Optional[str] = None,
structural_tag: Optional[str] = None,
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
no_stop_trim: bool = False,
custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.max_new_tokens = max_new_tokens
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
......@@ -58,26 +64,21 @@ class SamplingParams:
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
self.stop_strs = stop
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.max_new_tokens = max_new_tokens
self.min_new_tokens = min_new_tokens
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.regex = regex
self.n = n
self.json_schema = json_schema
self.ebnf = ebnf
self.structural_tag = structural_tag
self.ignore_eos = ignore_eos
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens
self.no_stop_trim = no_stop_trim
self.custom_params = custom_params
# Process some special cases
if self.temperature < _SAMPLING_EPS:
# top_k = 1 means greedy sampling
self.temperature = 1.0
self.top_k = 1
if self.top_k == -1:
......
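For reference, a standalone sketch (not the sglang class itself) of the two normalizations visible in this hunk: `stop_token_ids` is coerced to a set or `None`, and a near-zero `temperature` is rewritten into greedy sampling by forcing `top_k = 1`. The `_SAMPLING_EPS` value below is an assumption; the real constant is defined in `sampling_params.py`.
```python
# Illustrative stand-in for the normalization done in SamplingParams.__init__ above.
from typing import List, Optional

_SAMPLING_EPS = 1e-6  # assumed threshold; the real constant lives in sampling_params.py


def normalize(temperature: float, top_k: int, stop_token_ids: Optional[List[int]] = None):
    # Store stop token ids as a set for O(1) membership checks, or None if empty.
    stop_ids = set(stop_token_ids) if stop_token_ids else None
    # temperature < eps means greedy sampling: neutralize temperature and force top_k = 1.
    if temperature < _SAMPLING_EPS:
        temperature = 1.0
        top_k = 1
    return temperature, top_k, stop_ids


print(normalize(0.0, -1, [2, 2, 7]))  # -> (1.0, 1, {2, 7})
```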
......@@ -15,21 +15,15 @@
import argparse
import dataclasses
import json
import logging
import os
import random
import subprocess
import tempfile
import uuid
from pathlib import Path
from typing import List, Optional
import torch
from sglang.srt.hf_transformers_utils import check_gguf_file
from sglang.srt.utils import (
create_checksum,
get_amdgpu_memory_capacity,
get_hpu_memory_capacity,
get_nvgpu_memory_capacity,
......@@ -101,7 +95,7 @@ class ServerArgs:
# API related
api_key: Optional[str] = None
file_storage_pth: str = "sglang_storage"
file_storage_path: str = "sglang_storage"
enable_cache_report: bool = False
# Data parallelism
......@@ -149,7 +143,6 @@ class ServerArgs:
# Optimization/debug options
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
enable_nccl_nvls: bool = False
......@@ -627,9 +620,9 @@ class ServerArgs:
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
)
parser.add_argument(
"--file-storage-pth",
"--file-storage-path",
type=str,
default=ServerArgs.file_storage_pth,
default=ServerArgs.file_storage_path,
help="The path of the file storage in backend.",
)
parser.add_argument(
......@@ -836,11 +829,6 @@ class ServerArgs:
action="store_true",
help="Disable RadixAttention for prefix caching.",
)
parser.add_argument(
"--disable-jump-forward",
action="store_true",
help="Disable jump-forward for grammar-guided decoding.",
)
parser.add_argument(
"--disable-cuda-graph",
action="store_true",
......
......@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
def is_in_ci():
......
"""
Usage:
# single GPU
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B
python3 bench_speculative.py --model-path meta-llama/Llama-2-7b-chat-hf --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B
"""
import argparse
......
......@@ -17,3 +17,59 @@ For CUDA 12.1 or CUDA 12.4:
```bash
pip3 install sgl-kernel
```
# Developer Guide
## Development Environment Setup
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
Create and enter development container:
```bash
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
```
## Project Structure
### Dependencies
Third-party libraries:
- [CCCL](https://github.com/NVIDIA/cccl)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
- [TurboMind](https://github.com/InternLM/turbomind)
### Kernel Development
Steps to add a new kernel:
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) (see the wrapper sketch after this list)
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
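To make step 4 concrete, here is a minimal, deliberately hypothetical wrapper sketch: `my_new_kernel` and its signature are invented for illustration, and the only assumption taken from this diff is that ops are registered under the `sgl_kernels` namespace (see the `TORCH_LIBRARY_EXPAND(sgl_kernels, m)` hunk later in this commit).
```python
# Hypothetical Python wrapper (step 4), e.g. in src/sgl-kernel/ops/__init__.py.
# "my_new_kernel" and its (x, out) signature are made up for illustration; a real
# wrapper must match the schema declared with m.def(...) in torch_extension.cc.
import torch


def my_new_kernel(x: torch.Tensor, out: torch.Tensor) -> None:
    # Forward to the CUDA implementation bound via
    # m.impl("my_new_kernel", torch::kCUDA, &my_new_kernel).
    torch.ops.sgl_kernels.my_new_kernel(x, out)
```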
### Build & Install
Development build:
```bash
make build
```
Note:
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
### Testing & Benchmarking
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) (see the sketch after this list)
3. Run test suite
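A minimal sketch of item 2, assuming a CUDA device and current `triton.testing` APIs; `torch.softmax` is only a stand-in workload, so swap in the wrapper exposed in steps 4-5 above.
```python
# Sketch of a triton-based latency benchmark (item 2 above), assuming a CUDA GPU.
# torch.softmax is a placeholder workload; replace it with the sgl-kernel op under test.
import torch
import triton
import triton.testing


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],                 # sweep the token dimension
        x_vals=[128, 512, 2048, 8192],
        line_arg="provider",
        line_vals=["torch"],
        line_names=["torch.softmax (baseline)"],
        ylabel="latency (ms)",
        plot_name="softmax-bench",
        args={"hidden_size": 4096},
    )
)
def bench(num_tokens, hidden_size, provider):
    x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
    if provider == "torch":
        return triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))


if __name__ == "__main__":
    bench.run(print_data=True, show_plots=False)
```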
### Release new version
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
# Developer Guide for sgl-kernel
## Development Environment Setup
Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
Create and enter development container:
```bash
docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
docker exec -it sglang_zhyncs /bin/zsh
```
## Project Structure
### Dependencies
Third-party libraries:
- [CCCL](https://github.com/NVIDIA/cccl)
- [CUTLASS](https://github.com/NVIDIA/cutlass)
- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
- [TurboMind](https://github.com/InternLM/turbomind)
### Kernel Development
Steps to add a new kernel:
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h)
3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc)
4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
### Build & Install
Development build:
```bash
make build
```
Note:
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.
### Testing & Benchmarking
1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
3. Run test suite
### Release new version
Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/version.py)
......@@ -100,6 +100,7 @@ sources = [
"src/sgl-kernel/csrc/activation/fused_add_rms_norm_kernel.cu",
"src/sgl-kernel/csrc/allreduce/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/allreduce/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/attention/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/gemm/cublas_grouped_gemm.cu",
"src/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu",
"src/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu",
......@@ -108,7 +109,6 @@ sources = [
"src/sgl-kernel/csrc/moe/moe_align_kernel.cu",
"src/sgl-kernel/csrc/speculative/eagle_utils.cu",
"src/sgl-kernel/csrc/speculative/speculative_sampling.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/norm.cu",
......
......@@ -62,6 +62,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);
/*
* From csrc/attention
*/
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
/*
* From csrc/gemm
*/
......@@ -163,11 +168,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
"apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
"Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
/*
* Other
*/
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
}
REGISTER_EXTENSION(_kernels)
......@@ -46,7 +46,6 @@ class TestEBNFConstrained(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
cls.check_jump_forward = False
@classmethod
def tearDownClass(cls):
......@@ -238,12 +237,5 @@ class TestEBNFConstrained(unittest.TestCase):
)
class TestEBNFConstrainedLLGuidance(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
setup_class(cls, "llguidance", disable_overlap=False)
cls.check_jump_forward = False
if __name__ == "__main__":
unittest.main()
......@@ -57,7 +57,6 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="outlines", disable_overlap=False)
cls.check_jump_forward = False
@classmethod
def tearDownClass(cls):
......@@ -134,26 +133,5 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
list(executor.map(self.run_decode, json_schemas))
class TestJumpForwardOutlinesBackend(unittest.TestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="outlines", disable_overlap=True)
cls.check_jump_forward = True
class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="xgrammar", disable_overlap=False)
cls.check_jump_forward = False
class TestJSONConstrainedLLGuidanceBackend(TestJSONConstrainedOutlinesBackend):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="llguidance", disable_overlap=False)
cls.check_jump_forward = False
if __name__ == "__main__":
unittest.main()
......@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)
......@@ -49,6 +51,9 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.62)
if is_in_ci():
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
def test_human_eval(self):
args = SimpleNamespace(
base_url=self.base_url,
......@@ -61,6 +66,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.40)
if is_in_ci():
write_github_step_summary(
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
......@@ -73,6 +83,11 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.61)
if is_in_ci():
write_github_step_summary(
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
)
if __name__ == "__main__":
unittest.main()