"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "084669a43ecf6f5ad699dfcef6236de7135f2ad6"
Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
#include "ngram.h" #include "lookahead.h"
#include <limits> #include <limits>
#include <vector> #include <vector>
namespace ngram { namespace lookahead {
struct Node { struct Node {
std::unordered_map<int32_t, int32_t> next; std::unordered_map<int32_t, int32_t> next;
}; };
Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) { Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
Ngram::Result info; Lookahead::Result info;
std::vector<int32_t> prevs; std::vector<int32_t> prevs;
info.token.reserve(draft_token_num); info.token.reserve(draft_token_num);
prevs.reserve(draft_token_num); prevs.reserve(draft_token_num);
...@@ -50,7 +50,7 @@ Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& ...@@ -50,7 +50,7 @@ Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>&
return info; return info;
} }
Ngram::Ngram(size_t capacity, const Param& param) { Lookahead::Lookahead(size_t capacity, const Param& param) {
param_ = param; param_ = param;
nodes_.resize(capacity); nodes_.resize(capacity);
for (auto& node : nodes_) { for (auto& node : nodes_) {
...@@ -116,16 +116,17 @@ Ngram::Ngram(size_t capacity, const Param& param) { ...@@ -116,16 +116,17 @@ Ngram::Ngram(size_t capacity, const Param& param) {
} }
quit_flag_ = false; quit_flag_ = false;
insert_worker_ = std::thread(&Ngram::insert, this); insert_worker_ = std::thread(&Lookahead::insert, this);
} }
Ngram::~Ngram() { Lookahead::~Lookahead() {
quit_flag_ = true; quit_flag_ = true;
insert_queue_.close(); insert_queue_.close();
insert_worker_.join(); insert_worker_.join();
} }
std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const { std::vector<std::pair<TrieNode*, int32_t>>
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
auto draft_token_num = param_.get_draft_token_num(batch_size); auto draft_token_num = param_.get_draft_token_num(batch_size);
auto min_match_window_size = param_.get_min_match_window_size(batch_size); auto min_match_window_size = param_.get_min_match_window_size(batch_size);
auto max_match_window_size = param_.max_match_window_size; auto max_match_window_size = param_.max_match_window_size;
...@@ -153,7 +154,7 @@ std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_ ...@@ -153,7 +154,7 @@ std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_
return result; return result;
} }
void Ngram::squeeze(size_t count) { void Lookahead::squeeze(size_t count) {
if (!(node_pool_.size() >= free_node_count_ + count)) { if (!(node_pool_.size() >= free_node_count_ + count)) {
throw std::runtime_error( throw std::runtime_error(
"Insufficient node size to release required nodes. " "Insufficient node size to release required nodes. "
...@@ -176,13 +177,13 @@ void Ngram::squeeze(size_t count) { ...@@ -176,13 +177,13 @@ void Ngram::squeeze(size_t count) {
} }
} }
void Ngram::synchronize() const { void Lookahead::synchronize() const {
while (!insert_queue_.empty()) { while (!insert_queue_.empty()) {
std::this_thread::sleep_for(std::chrono::microseconds(10)); std::this_thread::sleep_for(std::chrono::microseconds(10));
} }
} }
void Ngram::insert() { void Lookahead::insert() {
while (!quit_flag_) { while (!quit_flag_) {
std::vector<int32_t> data; std::vector<int32_t> data;
if (!insert_queue_.dequeue(data)) { if (!insert_queue_.dequeue(data)) {
...@@ -238,13 +239,13 @@ void Ngram::insert() { ...@@ -238,13 +239,13 @@ void Ngram::insert() {
} }
} }
void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) { void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
for (auto&& token : tokens) { for (auto&& token : tokens) {
insert_queue_.enqueue(std::move(token)); insert_queue_.enqueue(std::move(token));
} }
} }
Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const { Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size); std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) / double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
...@@ -283,7 +284,7 @@ Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_s ...@@ -283,7 +284,7 @@ Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_s
return fillResult(tokens.back(), draft_token_num + 1, tree, root); return fillResult(tokens.back(), draft_token_num + 1, tree, root);
} }
Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const { Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size); std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
auto draft_token_num = param_.get_draft_token_num(batch_size); auto draft_token_num = param_.get_draft_token_num(batch_size);
...@@ -345,10 +346,10 @@ Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_ ...@@ -345,10 +346,10 @@ Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_
return fillResult(tokens.back(), draft_token_num + 1, tree, root); return fillResult(tokens.back(), draft_token_num + 1, tree, root);
} }
Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const { Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
Result merged_result; Result merged_result;
auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb; auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
for (const auto& tks : tokens) { for (const auto& tks : tokens) {
Result res = (this->*match_func)(tks, tokens.size()); Result res = (this->*match_func)(tks, tokens.size());
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end()); merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
...@@ -357,7 +358,7 @@ Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) ...@@ -357,7 +358,7 @@ Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens)
return merged_result; return merged_result;
} }
void Ngram::Result::truncate(size_t n) { void Lookahead::Result::truncate(size_t n) {
if (n < token.size()) { if (n < token.size()) {
int full_n = token.size(); int full_n = token.size();
for (int i = 1; i < n; ++i) { for (int i = 1; i < n; ++i) {
...@@ -368,4 +369,4 @@ void Ngram::Result::truncate(size_t n) { ...@@ -368,4 +369,4 @@ void Ngram::Result::truncate(size_t n) {
} }
} }
} // namespace ngram } // namespace lookahead
@@ -15,7 +15,7 @@
#include "param.h"
#include "queue.h"
-namespace ngram {
+namespace lookahead {
struct TrieNode {
std::unordered_map<int32_t, TrieNode*> child;
@@ -34,7 +34,7 @@ struct TrieNode {
std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
-class Ngram {
+class Lookahead {
std::vector<TrieNode> nodes_;
std::vector<TrieNode*> node_pool_;
size_t free_node_count_;
@@ -61,12 +61,12 @@ class Ngram {
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
public:
-Ngram(size_t capacity, const Param& param);
-Ngram() = default;
-~Ngram();
+Lookahead(size_t capacity, const Param& param);
+Lookahead() = default;
+~Lookahead();
-static Ngram& instance() {
-static Ngram instance;
+static Lookahead& instance() {
+static Lookahead instance;
return instance;
}
@@ -107,4 +107,4 @@ class Ngram {
void insert();
};
-} // namespace ngram
+} // namespace lookahead
# -*- coding: utf-8 -*-
+# from sglang.op.lookahead import Lookahead, Param
import logging
import os
from typing import List, Tuple
@@ -10,17 +12,17 @@ from torch.utils.cpp_extension import load
logger = logging.getLogger(__name__)
_abs_path = os.path.dirname(os.path.abspath(__file__))
-ngram_cache_cpp = load(
-name="ngram_cache_cpp",
+lookahead_cache_cpp = load(
+name="lookahead_cache_cpp",
sources=[
-f"{_abs_path}/ngram_cache_binding.cpp",
-f"{_abs_path}/ngram.cpp",
+f"{_abs_path}/lookahead_cache_binding.cpp",
+f"{_abs_path}/lookahead.cpp",
],
extra_cflags=["-O3", "-std=c++20"],
)
-class NgramCache:
+class LookaheadCache:
def __init__(
self,
branch_length=18,
@@ -32,7 +34,7 @@ class NgramCache:
match_type="BFS",
capacity=1000000,
):
-param = ngram_cache_cpp.Param()
+param = lookahead_cache_cpp.Param()
param.branch_length = branch_length
param.min_match_window_size = min_match_window_size
param.max_match_window_size = max_match_window_size
@@ -40,7 +42,7 @@ class NgramCache:
param.max_bfs_breadth = max_bfs_breadth
param.draft_token_num = draft_token_num
param.match_type = match_type
-self.cache = ngram_cache_cpp.Ngram(capacity, param)
+self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
self.default_mask = np.ones((1, 1), dtype=np.int64)
self.draft_token_num = draft_token_num
@@ -129,7 +131,7 @@ if __name__ == "__main__":
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
]
-cache = NgramCache(branch_length=12, draft_token_num=8)
+cache = LookaheadCache(branch_length=12, draft_token_num=8)
cache.batch_put(token_ids)
cache.synchronize()
......
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
-#include "ngram.h"
+#include "lookahead.h"
-PYBIND11_MODULE(ngram_cache_cpp, m) {
-using namespace ngram;
+PYBIND11_MODULE(lookahead_cache_cpp, m) {
+using namespace lookahead;
namespace py = pybind11;
m.doc() = "";
-py::class_<Ngram>(m, "Ngram")
+py::class_<Lookahead>(m, "Lookahead")
.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
-.def("asyncInsert", &Ngram::asyncInsert, "")
-.def("batchMatch", &Ngram::batchMatch, "")
-.def("reset", &Ngram::reset, "")
-.def("synchronize", &Ngram::synchronize, "");
+.def("asyncInsert", &Lookahead::asyncInsert, "")
+.def("batchMatch", &Lookahead::batchMatch, "")
+.def("reset", &Lookahead::reset, "")
+.def("synchronize", &Lookahead::synchronize, "");
py::class_<Param>(m, "Param")
.def(py::init<>())
@@ -35,9 +35,9 @@ PYBIND11_MODULE(ngram_cache_cpp, m) {
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
.def("detail", &Param::detail, "");
-py::class_<Ngram::Result>(m, "Result")
+py::class_<Lookahead::Result>(m, "Result")
.def(py::init<>())
-.def_readwrite("token", &Ngram::Result::token)
-.def_readwrite("mask", &Ngram::Result::mask)
-.def("truncate", &Ngram::Result::truncate);
+.def_readwrite("token", &Lookahead::Result::token)
+.def_readwrite("mask", &Lookahead::Result::mask)
+.def("truncate", &Lookahead::Result::truncate);
}
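For orientation, here is a minimal sketch of driving the renamed extension module directly from Python. Only methods bound above and fields set in lookahead_cache.py earlier in this commit are used; the concrete values are illustrative, and in practice the module is built and imported via torch.utils.cpp_extension.load rather than a plain import.

# Illustrative sketch only: exercises the pybind11 surface defined above.
import lookahead_cache_cpp as lc

param = lc.Param()
param.branch_length = 12               # fields mirror LookaheadCache.__init__
param.min_match_window_size = 1
param.max_match_window_size = 8
param.min_bfs_breadth = 1
param.max_bfs_breadth = 8
param.draft_token_num = 8
param.match_type = "BFS"

cache = lc.Lookahead(1000000, param)             # capacity, param
cache.asyncInsert([[1, 2, 3, 4, 5, 6, 7, 8]])    # consumed by the background insert thread
cache.synchronize()                              # block until the insert queue drains

result = cache.batchMatch([[1, 2, 3]])           # Lookahead::Result
print(result.token, result.mask)
result.truncate(4)                               # keep at most 4 draft tokens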
@@ -9,7 +9,7 @@
#include <string>
#include <vector>
-namespace ngram {
+namespace lookahead {
struct Param {
bool enable;
@@ -122,4 +122,4 @@ struct Param {
}
};
-} // namespace ngram
+} // namespace lookahead
@@ -13,7 +13,6 @@ import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.environ import envs
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
@@ -24,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda():
@@ -43,8 +42,8 @@ logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
-SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
-SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
@@ -501,12 +500,13 @@ class EagleVerifyInput:
deterministic=True,
)
-if SIMULATE_ACC_LEN > 0.0:
+if SIMULATE_ACC_LEN:
# Do simulation
accept_index = _generate_simulated_accept_index(
accept_index=accept_index,
predict=predict, # mutable
accept_length=accept_length, # mutable
+simulate_acc_len=SIMULATE_ACC_LEN,
bs=bs,
spec_steps=self.spec_steps,
)
@@ -1131,16 +1131,14 @@ def _generate_simulated_accept_index(
accept_index,
predict,
accept_length,
+simulate_acc_len,
bs,
spec_steps,
-simulate_acc_len: float = SIMULATE_ACC_LEN,
-simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
-assert simulate_acc_len > 0.0
-if simulate_acc_method == "multinomial":
+simulate_acc_len_float = float(simulate_acc_len)
+if SIMULATE_ACC_METHOD == "multinomial":
simulated_values = torch.normal(
-mean=simulate_acc_len,
+mean=simulate_acc_len_float,
std=1.0,
size=(1,),
device="cpu",
@@ -1148,19 +1146,19 @@ def _generate_simulated_accept_index(
# clamp simulated values to be between 1 and self.spec_steps
simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
simulate_acc_len = int(simulated_values.round().item())
-elif simulate_acc_method == "match-expected":
+elif SIMULATE_ACC_METHOD == "match-expected":
# multinomial sampling does not match the expected length
# we keep it for the sake of compatibility of existing tests
# but it's better to use "match-expected" for the cases that need to
# match the expected length, One caveat is that this will only sample
# either round down or round up of the expected length
-simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
-lower = int(simulate_acc_len // 1)
+simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+lower = int(simulate_acc_len_float // 1)
upper = lower + 1 if lower < spec_steps + 1 else lower
if lower == upper:
simulate_acc_len = lower
else:
-weight_upper = simulate_acc_len - lower
+weight_upper = simulate_acc_len_float - lower
weight_lower = 1.0 - weight_upper
probs = torch.tensor([weight_lower, weight_upper], device="cpu")
sampled_index = torch.multinomial(probs, num_samples=1)
......
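The hunks above revert the module-level switches to plain os.environ lookups and thread simulate_acc_len through as an explicit argument. As a standalone illustration of the "match-expected" branch (a hedged sketch, not code from the diff; the function name is made up), the target length is split into its floor and ceiling and one of them is sampled so the expectation matches:

# Hedged, self-contained illustration of the "match-expected" sampling above.
import torch

def sample_match_expected(simulate_acc_len: float, spec_steps: int) -> int:
    # Clamp the requested mean acceptance length into [1, spec_steps + 1].
    target = max(1.0, min(spec_steps + 1, simulate_acc_len))
    lower = int(target // 1)
    upper = lower + 1 if lower < spec_steps + 1 else lower
    if lower == upper:
        return lower
    # Pick floor or ceil with weights so that E[result] == target.
    weight_upper = target - lower
    probs = torch.tensor([1.0 - weight_upper, weight_upper])
    sampled_index = torch.multinomial(probs, num_samples=1).item()
    return lower if sampled_index == 0 else upper

# E.g. with simulate_acc_len=2.3 and spec_steps=4 this returns 2 with
# probability 0.7 and 3 with probability 0.3.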
@@ -42,7 +42,7 @@ elif is_hip():
@dataclass
-class NgramVerifyInput:
+class LookaheadVerifyInput:
def __init__(
self,
draft_token: torch.Tensor,
@@ -405,8 +405,8 @@ class NgramVerifyInput:
return logits_output, self.verified_id, self.accept_length.sum().item()
-def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+def filter_batch(self, new_indices: torch.Tensor):
pass
-def merge_batch(self, spec_info: NgramVerifyInput):
+def merge_batch(self, spec_info: LookaheadVerifyInput):
pass
@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import broadcast_pyobj
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
USE_FULL_MASK = True
-class NGRAMWorker:
+class LOOKAHEADWorker:
def __init__(
self,
server_args: ServerArgs,
@@ -38,9 +38,9 @@ class NGRAMWorker:
self.tp_rank = tp_rank
self.page_size = server_args.page_size
self.draft_token_num: int = server_args.speculative_num_draft_tokens
-self.branch_length: int = server_args.speculative_ngram_branch_length
+self.branch_length: int = server_args.speculative_lookahead_branch_length
self.max_match_window_size: int = (
-server_args.speculative_ngram_max_match_window_size
+server_args.speculative_lookahead_max_match_window_size
)
self.max_batch_size = target_worker.max_running_requests
@@ -48,18 +48,18 @@ class NGRAMWorker:
self._init_preallocated_tensors()
-self.ngram_cache = NgramCache(
-min_match_window_size=server_args.speculative_ngram_min_match_window_size,
-max_match_window_size=server_args.speculative_ngram_max_match_window_size,
-min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
-max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
-capacity=server_args.speculative_ngram_capacity,
-branch_length=server_args.speculative_ngram_branch_length,
+self.lookahead_cache = LookaheadCache(
+min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
+max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
+min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
+max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
+capacity=server_args.speculative_lookahead_capacity,
+branch_length=server_args.speculative_lookahead_branch_length,
draft_token_num=server_args.speculative_num_draft_tokens,
)
def clear_cache_pool(self):
-self.ngram_cache.reset()
+self.lookahead_cache.reset()
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
seq2_len = len(seq2)
@@ -124,14 +124,14 @@ class NGRAMWorker:
) -> tuple[np.ndarray, np.ndarray]:
bs = batch.batch_size()
-self.ngram_cache.synchronize()
+self.lookahead_cache.synchronize()
batch_tokens = []
for req in batch.reqs:
check_token = self._efficient_concat_last_n(
req.origin_input_ids, req.output_ids, self.max_match_window_size
)
batch_tokens.append(check_token)
-req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
+req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
total_draft_token_num = len(req_drafts)
# Check if speculative decoding is needed; here we always enforce it
@@ -184,9 +184,9 @@ class NGRAMWorker:
tree_mask.append(req_mask.flatten())
tree_mask = torch.cat(tree_mask, dim=0)
-batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
+batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
batch.forward_mode = ForwardMode.TARGET_VERIFY
-batch.spec_info = NgramVerifyInput(
+batch.spec_info = LookaheadVerifyInput(
draft_tokens,
tree_mask,
positions,
@@ -197,7 +197,7 @@ class NGRAMWorker:
)
batch.spec_info.prepare_for_verify(batch, self.page_size)
-def _update_ngram_cache(self, batch: ScheduleBatch):
+def _update_lookahead_cache(self, batch: ScheduleBatch):
batch_tokens = []
for req in batch.reqs:
# FIXME: Whether to insert 'extend' into the cache or not, after testing,
@@ -209,7 +209,7 @@ class NGRAMWorker:
req.origin_input_ids, req.output_ids, self.branch_length
)
batch_tokens.append(put_ids)
-self.ngram_cache.batch_put(batch_tokens)
+self.lookahead_cache.batch_put(batch_tokens)
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
self._prepare_for_speculative_decoding(batch)
@@ -227,7 +227,7 @@ class NGRAMWorker:
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
batch, logits_output, self.page_size
)
-self._update_ngram_cache(batch)
+self._update_lookahead_cache(batch)
batch.forward_mode = ForwardMode.DECODE
else:
......
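Seen from the worker above, the cache round trip reduces to a few calls. The following hedged sketch reuses only methods visible in this commit (batch_put, synchronize, batch_get); the token ids are made up.

# Hedged sketch of the worker's cache round trip (illustrative values only).
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache

cache = LookaheadCache(branch_length=12, draft_token_num=8)

# After verification, accepted token streams are inserted asynchronously ...
cache.batch_put([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
cache.synchronize()  # drain the background insert queue before matching

# ... and at draft time the recent suffix of each request is matched to get
# draft tokens plus a tree attention mask, as in _prepare_draft_tokens above.
req_drafts, mask = cache.batch_get([[1, 2, 3]])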
@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
EAGLE = auto()
EAGLE3 = auto()
STANDALONE = auto()
-NGRAM = auto()
+LOOKAHEAD = auto()
def is_none(self):
return self == SpeculativeAlgorithm.NONE
@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
def is_standalone(self):
return self == SpeculativeAlgorithm.STANDALONE
-def is_ngram(self):
-return self == SpeculativeAlgorithm.NGRAM
+def is_lookahead(self):
+return self == SpeculativeAlgorithm.LOOKAHEAD
@staticmethod
def from_string(name: str):
@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
"EAGLE": SpeculativeAlgorithm.EAGLE,
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
-"NGRAM": SpeculativeAlgorithm.NGRAM,
+"LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
None: SpeculativeAlgorithm.NONE,
}
if name is not None:
......
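A quick check of the new enum value, assuming from_string resolves names through the mapping shown above (its tail is folded in this hunk):

# Hedged example: LOOKAHEAD replaces NGRAM in the algorithm registry.
algo = SpeculativeAlgorithm.from_string("LOOKAHEAD")
assert algo.is_lookahead() and not algo.is_none()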
@@ -31,7 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
......
@@ -22,7 +22,6 @@ import ctypes
import dataclasses
import functools
import importlib
-import inspect
import io
import ipaddress
import itertools
@@ -195,7 +194,7 @@ _warned_bool_env_var_keys = set()
def get_bool_env_var(name: str, default: str = "false") -> bool:
-# FIXME: move your environment variable to sglang.srt.environ
+# FIXME: move your environment variable to sglang.environ
value = os.getenv(name, default)
value = value.lower()
@@ -213,7 +212,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
def get_int_env_var(name: str, default: int = 0) -> int:
-# FIXME: move your environment variable to sglang.srt.environ
+# FIXME: move your environment variable to sglang.environ
value = os.getenv(name)
if value is None or not value.strip():
return default
@@ -471,7 +470,7 @@ def is_pin_memory_available() -> bool:
class LayerFn(Protocol):
-def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...
def make_layers(
@@ -482,7 +481,7 @@ def make_layers(
prefix: str = "",
return_tuple: bool = False,
offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
"""Make a list of layers with the given layer function"""
# circula imports
from sglang.srt.distributed import get_pp_indices
@@ -518,50 +517,6 @@ def make_layers(
return modules, start_layer, end_layer
-cmo_stream = None
-def get_cmo_stream():
-"""
-Cache Management Operation(CMO).
-Launch a new stream to prefetch the weight of matmul when running other
-AIV or communication kernels, aiming to overlap the memory access time.
-"""
-global cmo_stream
-if cmo_stream is None:
-cmo_stream = torch.get_device_module().Stream()
-return cmo_stream
-def prepare_weight_cache(handle, cache):
-import torch_npu
-NPU_PREFETCH_MAX_SIZE_BYTES = (
-1000000000 # 1GB, a large value to prefetch entire weight
-)
-stream = get_cmo_stream()
-stream.wait_stream(torch.npu.current_stream())
-with torch.npu.stream(stream):
-if isinstance(cache, list):
-for weight in cache:
-torch_npu.npu_prefetch(
-weight,
-handle,
-NPU_PREFETCH_MAX_SIZE_BYTES,
-)
-else:
-torch_npu.npu_prefetch(
-cache,
-handle,
-NPU_PREFETCH_MAX_SIZE_BYTES,
-)
-def wait_cmo_stream():
-cur_stream = torch.get_device_module().current_stream()
-cur_stream.wait_stream(get_cmo_stream())
def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed)
@@ -2054,6 +2009,13 @@ def set_uvicorn_logging_configs():
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+def get_ip() -> Optional[str]:
+host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+if host_ip:
+return host_ip
+return None
def get_open_port() -> int:
port = os.getenv("SGLANG_PORT")
if port is not None:
@@ -2393,10 +2355,8 @@ def get_local_ip_auto(fallback: str = None) -> str:
2. Network interface enumeration via get_local_ip_by_nic()
3. Remote connection method via get_local_ip_by_remote()
"""
-# Try environment variable
-host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-if host_ip:
-return host_ip
+if ip := get_ip():
+return ip
logger.debug("get_ip failed")
# Fallback
if ip := get_local_ip_by_nic():
@@ -2460,7 +2420,7 @@ class BumpAllocator:
def log_info_on_rank0(logger, msg):
from sglang.srt.distributed import get_tensor_model_parallel_rank
-if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
+if get_tensor_model_parallel_rank() == 0:
logger.info(msg)
@@ -3220,120 +3180,3 @@ def get_extend_input_len_swa_limit(
# and we can only free out-of-sliding-window kv indices after each prefill.
# 3. page_size is because we want to have 1 token extra for generated tokens.
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
class CachedKernel:
"""
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
This wrapper caches compiled Triton kernels based on keys extracted by a
user-provided key function to avoid redundant compilations.
"""
def __init__(self, fn, key_fn=None):
self.fn = fn
assert isinstance(fn, triton.runtime.jit.JITFunction)
original_fn = fn.fn
self.signature = inspect.signature(original_fn)
self.param_names = tuple(self.signature.parameters.keys())
self.num_args = len(self.param_names)
# Check that no parameters have default values
for name, param in self.signature.parameters.items():
assert (
param.default is inspect.Parameter.empty
), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
functools.update_wrapper(self, original_fn)
self.kernel_cache = {}
# Store the key function
self.key_fn = key_fn
def __getitem__(self, grid):
"""
Index with grid to get a launcher function.
Returns a launcher that will handle caching based on the key function.
"""
assert (
isinstance(grid, tuple) and len(grid) <= 3
), "Grid must be a tuple with at most 3 dimensions."
# Normalize grid once
if len(grid) < 3:
grid = grid + (1,) * (3 - len(grid))
def launcher(*args, **kwargs):
cache_key = self.key_fn(args, kwargs)
cached_kernel = self.kernel_cache.get(cache_key)
if cached_kernel is None:
# First time: compile and cache the kernel
cached_kernel = self.fn[grid](*args, **kwargs)
self.kernel_cache[cache_key] = cached_kernel
return cached_kernel
else:
# Use cached kernel
all_args = self._build_args(args, kwargs)
cached_kernel[grid](*all_args)
return cached_kernel
return launcher
def _build_args(self, args, kwargs):
"""
Build the complete argument list for kernel invocation.
"""
complete_args = list(args)
for i in range(len(args), self.num_args):
name = self.param_names[i]
value = kwargs.get(name, inspect.Parameter.empty)
if value is not inspect.Parameter.empty:
complete_args.append(value)
else:
raise ValueError(f"Missing argument: {name}")
return complete_args
def _clear_cache(self):
"""
Clear the kernel cache for testing purposes.
"""
self.kernel_cache.clear()
def cached_triton_kernel(key_fn=None):
"""
Decorator that enables key-based caching for Triton kernels using a key function.
It essentially bypasses Triton's built-in caching mechanism, allowing users to
define their own caching strategy based on kernel parameters. This helps reduce
the heavy overheads of Triton kernel launch when the kernel specialization dispatch
is simple.
Usage:
@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
@triton.jit
def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
...
# Invoke normally
my_kernel[grid](x, y, BLOCK_SIZE=1024)
Args:
key_fn: A function that takes (args, kwargs) and returns the cache key(s).
The key can be a single value or a tuple of values.
Returns:
A decorator that wraps the kernel with caching functionality.
Note: Kernels with default parameter values are not supported and will raise an assertion error.
"""
def decorator(fn):
return CachedKernel(fn, key_fn)
return decorator
@@ -60,11 +60,6 @@ def run_eval(args):
from sglang.test.simple_eval_humaneval import HumanEval
eval_obj = HumanEval(args.num_examples, args.num_threads)
-elif args.eval_name == "mmmu":
-# VLM MMMU evaluation with fixed 100 examples by default
-from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
-eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
else:
raise ValueError(f"Invalid eval name: {args.eval_name}")
@@ -99,8 +94,6 @@ def run_eval(args):
print(f"Total latency: {latency:.3f} s")
print(f"Score: {metrics['score']:.3f}")
-if getattr(args, "return_latency", False):
-return metrics, latency
return metrics
......
@@ -136,7 +136,7 @@ class ChatCompletionSampler(SamplerBase):
self._pack_message("system", self.system_message)
] + message_list
trial = 0
-while trial < 6: # 126 seconds in total
+while True:
try:
response = self.client.chat.completions.create(
model=self.model,
......
"""
MMMU evaluation for VLMs using the run_eval simple-evals interface.
"""
from __future__ import annotations
import base64
import io
from typing import List, Optional, Tuple
from datasets import concatenate_datasets, load_dataset
from PIL import Image
from sglang.test import simple_eval_common as common
from sglang.test.simple_eval_common import (
HTML_JINJA,
Eval,
EvalResult,
SamplerBase,
SingleEvalResult,
map_with_progress,
)
class MMMUVLMEval(Eval):
DOMAIN_CAT2SUB_CAT = {
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
"Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
"Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
"Health and Medicine": [
"Basic_Medical_Science",
"Clinical_Medicine",
"Diagnostics_and_Laboratory_Medicine",
"Pharmacy",
"Public_Health",
],
"Humanities and Social Science": [
"History",
"Literature",
"Sociology",
"Psychology",
],
"Tech and Engineering": [
"Agriculture",
"Architecture_and_Engineering",
"Computer_Science",
"Electronics",
"Energy_and_Power",
"Materials",
"Mechanical_Engineering",
],
}
def __init__(
self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
):
"""Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
self.num_examples = num_examples
self.num_threads = num_threads
self.seed = seed
# Prepare samples deterministically across all MMMU subjects (validation split)
self.samples = self._prepare_mmmu_samples(self.num_examples)
@staticmethod
def _to_data_uri(image: Image.Image) -> str:
if image.mode == "RGBA":
image = image.convert("RGB")
buf = io.BytesIO()
image.save(buf, format="PNG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
return f"data:image/png;base64,{b64}"
@staticmethod
def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
index2ans = {}
all_choices = []
ch = ord("A")
for opt in options:
letter = chr(ch)
index2ans[letter] = opt
all_choices.append(letter)
ch += 1
return index2ans, all_choices
def _prepare_mmmu_samples(self, k: int) -> List[dict]:
# Subjects and domains copied from MMMU data_utils to categorize results
subjects: List[str] = []
for subs in self.DOMAIN_CAT2SUB_CAT.values():
subjects.extend(subs)
# Load validation split of each subject
datasets = []
for subj in subjects:
try:
d = load_dataset("MMMU/MMMU", subj, split="validation")
# attach subject info via transform
d = d.add_column("__subject__", [subj] * len(d))
datasets.append(d)
except Exception:
continue
if not datasets:
raise RuntimeError("Failed to load MMMU datasets")
merged = concatenate_datasets(datasets)
# Deterministic selection: sort by id (fallback to subject+index)
def _key(idx):
ex = merged[idx]
return str(ex.get("id", f"{ex['__subject__']}:{idx}"))
order = sorted(range(len(merged)), key=_key)
picked_indices = order[:k]
samples: List[dict] = []
for idx in picked_indices:
ex = merged[idx]
subject = ex["__subject__"]
image = ex.get("image_1")
if image is None or not hasattr(image, "convert"):
continue
data_uri = self._to_data_uri(image)
question = ex.get("question", "")
answer = ex.get("answer")
raw_options = ex.get("options")
question_type = "open"
index2ans = None
all_choices = None
options = None
if raw_options:
try:
options = (
raw_options
if isinstance(raw_options, list)
else list(eval(raw_options))
)
if isinstance(options, list) and len(options) > 0:
index2ans, all_choices = self._build_mc_mapping(options)
question_type = "multiple-choice"
except Exception:
options = None
# Build final textual prompt; include choices if MC
prompt_text = f"Question: {question}\n\n"
if options:
letters = [chr(ord("A") + i) for i in range(len(options))]
for letter, opt in zip(letters, options):
prompt_text += f"{letter}) {opt}\n"
prompt_text += "\nAnswer: "
samples.append(
{
"id": ex.get("id", f"{subject}:{idx}"),
"final_input_prompt": prompt_text,
"image_data": data_uri,
"answer": answer,
"question_type": question_type,
"index2ans": index2ans,
"all_choices": all_choices,
"category": subject,
}
)
return samples
@staticmethod
def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
"""Split a prompt containing an inline image tag into prefix and suffix.
If no tag is present, treat the whole prompt as prefix and empty suffix.
"""
if "<" in prompt and ">" in prompt:
prefix = prompt.split("<")[0]
suffix = prompt.split(">", 1)[1]
return prefix, suffix
return prompt, ""
@staticmethod
def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
"""Split a prompt containing an inline image tag into prefix and suffix.
If no tag is present, treat the whole prompt as prefix and empty suffix.
"""
# Build a vision+text message for OpenAI-compatible API
prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
content: List[dict] = []
if prefix:
content.append({"type": "text", "text": prefix})
content.append({"type": "image_url", "image_url": {"url": image_data}})
if suffix:
content.append({"type": "text", "text": suffix})
prompt_messages = [{"role": "user", "content": content}]
return prompt_messages
def __call__(self, sampler: SamplerBase) -> EvalResult:
def fn(sample: dict):
prompt = sample["final_input_prompt"]
image_data = sample["image_data"]
prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
prompt, image_data
)
# Sample
response_text = sampler(prompt_messages)
# Parse and score
gold = sample["answer"]
if (
sample["question_type"] == "multiple-choice"
and sample["all_choices"]
and sample["index2ans"]
):
pred = _parse_multi_choice_response(
response_text, sample["all_choices"], sample["index2ans"]
)
score = 1.0 if (gold is not None and pred == gold) else 0.0
extracted_answer = pred
else:
parsed_list = _parse_open_response(response_text)
score = (
1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
)
extracted_answer = ", ".join(map(str, parsed_list))
html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
prompt_messages=prompt_messages,
next_message=dict(content=response_text, role="assistant"),
score=score,
correct_answer=gold,
extracted_answer=extracted_answer,
)
convo = prompt_messages + [dict(content=response_text, role="assistant")]
return SingleEvalResult(
html=html_rendered,
score=score,
metrics={"__category__": sample["category"]},
convo=convo,
)
results = map_with_progress(fn, self.samples, self.num_threads)
# Build category table and overall accuracy
# Gather per-sample correctness and category
per_cat_total: dict[str, int] = {}
per_cat_correct: dict[str, int] = {}
htmls = []
convos = []
scores: List[float] = []
for r in results:
# __category__ stored under metrics
cat = r.metrics.get("__category__") if r.metrics else None
if cat is None:
cat = "Unknown"
per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
if r.score:
per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
htmls.append(r.html)
convos.append(r.convo)
if r.score is not None:
scores.append(r.score)
evaluation_result = {}
for cat, tot in per_cat_total.items():
corr = per_cat_correct.get(cat, 0)
acc = (corr / tot) if tot > 0 else 0.0
evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}
printable_results = {}
# Domains first
for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
acc_sum = 0.0
num_sum = 0
for cat in cats:
if cat in evaluation_result:
acc_sum += (
evaluation_result[cat]["acc"]
* evaluation_result[cat]["num_example"]
)
num_sum += evaluation_result[cat]["num_example"]
if num_sum > 0:
printable_results[f"Overall-{domain}"] = {
"num": num_sum,
"acc": round(acc_sum / num_sum, 3),
}
# add each sub-category row if present
for cat in cats:
if cat in evaluation_result:
printable_results[cat] = {
"num": evaluation_result[cat]["num_example"],
"acc": evaluation_result[cat]["acc"],
}
# Overall
total_num = sum(v["num_example"] for v in evaluation_result.values())
overall_acc = (
sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
/ total_num
if total_num > 0
else 0.0
)
printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}
# Build EvalResult
return EvalResult(
score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
)
def _parse_multi_choice_response(
response: str, all_choices: List[str], index2ans: dict
) -> str:
# loosely adapted from benchmark mmmu eval
for char in [",", ".", "!", "?", ";", ":", "'"]:
response = response.strip(char)
response = " " + response + " "
# Prefer explicit letter with bracket e.g. (A)
candidates: List[str] = []
for choice in all_choices:
if f"({choice})" in response:
candidates.append(choice)
if not candidates:
for choice in all_choices:
if f" {choice} " in response:
candidates.append(choice)
if not candidates and len(response.split()) > 5:
# try match by option text
for idx, ans in index2ans.items():
if ans and ans.lower() in response.lower():
candidates.append(idx)
if not candidates:
# fallback to first choice
return all_choices[0]
if len(candidates) == 1:
return candidates[0]
# choose the last occurrence
starts = []
for can in candidates:
pos = response.rfind(f"({can})")
if pos == -1:
pos = response.rfind(f" {can} ")
if pos == -1 and index2ans.get(can):
pos = response.lower().rfind(index2ans[can].lower())
starts.append(pos)
return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
def _check_is_number(s: str) -> bool:
try:
float(s.replace(",", ""))
return True
except Exception:
return False
def _normalize_str(s: str):
s = s.strip()
if _check_is_number(s):
s = s.replace(",", "")
try:
v = round(float(s), 2)
return [v]
except Exception:
return [s.lower()]
return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
def _extract_numbers(s: str) -> List[str]:
import re as _re
pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
return (
_re.findall(pattern_commas, s)
+ _re.findall(pattern_scientific, s)
+ _re.findall(pattern_simple, s)
)
def _parse_open_response(response: str) -> List[str]:
import re as _re
def get_key_subresponses(resp: str) -> List[str]:
resp = resp.strip().strip(".").lower()
subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
indicators = [
"could be ",
"so ",
"is ",
"thus ",
"therefore ",
"final ",
"answer ",
"result ",
]
keys = []
for i, s in enumerate(subs):
cands = [*indicators]
if i == len(subs) - 1:
cands.append("=")
shortest = None
for ind in cands:
if ind in s:
part = s.split(ind)[-1].strip()
if not shortest or len(part) < len(shortest):
shortest = part
if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
keys.append(shortest)
return keys or [resp]
key_resps = get_key_subresponses(response)
pred_list = key_resps.copy()
for r in key_resps:
pred_list.extend(_extract_numbers(r))
out = []
for x in pred_list:
out.extend(_normalize_str(x))
# dedup
return list(dict.fromkeys(out))
def _eval_open(gold, preds: List[str]) -> bool:
if isinstance(gold, list):
norm_answers = []
for ans in gold:
norm_answers.extend(_normalize_str(ans))
else:
norm_answers = _normalize_str(gold)
for p in preds:
if isinstance(p, str):
for na in norm_answers:
if isinstance(na, str) and na in p:
return True
else:
if p in norm_answers:
return True
return False
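For context on what this commit deletes: the two parsers above resolve a sampler response either to a choice letter or to a list of normalized open-form answers. A hedged example of their behavior on made-up inputs:

# Illustrative inputs for the removed helpers defined above; values are made up.
all_choices = ["A", "B", "C", "D"]
index2ans = {"A": "Paris", "B": "Rome", "C": "Berlin", "D": "Madrid"}

_parse_multi_choice_response("The capital is (B) Rome.", all_choices, index2ans)
# -> "B" (an explicit "(B)" wins; otherwise bare letters, then option text)

_eval_open("3.14", _parse_open_response("So the answer is 3.14"))
# -> True (open responses are split into key sub-answers and normalized numbers)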
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
w_s,
)
-from deep_gemm import fp8_m_grouped_gemm_nt_masked
+from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
-fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
out = oe[:, :M, :]
self.assertTrue(
......
@@ -19,7 +19,7 @@ from sglang.profiler import run_profile
PROMPT_1 = "Tell me about Richard Feynman: "
PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+with open("python/sglang/test/long_prompt.txt", "r") as f:
LONG_PROMPT = f.read()
......
@@ -14,12 +14,10 @@ import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
-from datetime import datetime
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, List, Optional, Tuple
-from urllib.parse import quote
import aiohttp
import numpy as np
@@ -82,7 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
"meta-llama/Llama-3.1-8B-Instruct"
)
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
-DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
+DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Other use cases
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
@@ -1469,146 +1467,3 @@ def dump_bench_raw_result(
def _ensure_remove_suffix(text: str, suffix: str):
assert text.endswith(suffix)
return text.removesuffix(suffix)
class ModelDeploySetup:
def __init__(self, model_path: str, extra_args: List[str] = []):
self.model_path = model_path
if "--enable-multimodal" not in extra_args:
extra_args.append("--enable-multimodal")
if "--trust-remote-code" not in extra_args:
extra_args.append("--trust-remote-code")
self.extra_args = extra_args
class ModelEvalMetrics:
def __init__(self, accuracy: float, eval_time: float):
self.accuracy = accuracy
self.eval_time = eval_time
def extract_trace_link_from_bench_one_batch_server_output(output: str) -> str:
match = re.search(r"\[Profile\]\((.*?)\)", output)
if match:
trace_link = match.group(1)
return trace_link
return None
def parse_models(model_string: str):
return [model.strip() for model in model_string.split(",") if model.strip()]
def check_evaluation_test_results(
results,
test_name,
model_accuracy_thresholds,
model_latency_thresholds=None,
model_count=None,
):
"""
results: list of tuple of (model_path, accuracy, latency)
"""
failed_models = []
if model_latency_thresholds is not None:
summary = " | model | status | score | score_threshold | latency | latency_threshold | \n"
summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- | \n"
else:
summary = " | model | status | score | score_threshold | \n"
summary += "| ----- | ------ | ----- | --------------- | \n"
results_dict = {res[0]: (res[1], res[2]) for res in results}
for model, accuracy_threshold in sorted(model_accuracy_thresholds.items()):
latency_threshold = (
model_latency_thresholds.get(model)
if model_latency_thresholds is not None
else 1e9
)
if model in results_dict:
accuracy, latency = results_dict[model]
is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
status_emoji = "✅" if is_success else "❌"
if not is_success:
if accuracy < accuracy_threshold:
failed_models.append(
f"\nScore Check Failed: {model}\n"
f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
)
if latency > latency_threshold:
failed_models.append(
f"\nLatency Check Failed: {model}\n"
f"Model {model} latency ({latency:.4f}) is above threshold ({latency_threshold:.4f})"
)
if model_latency_thresholds is not None:
line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold}\n"
else:
line = (
f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold}\n"
)
else:
status_emoji = "❌"
failed_models.append(f"Model failed to launch or be evaluated: {model}")
if model_latency_thresholds is not None:
line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold} | N/A | {latency_threshold}\n"
else:
line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold}\n"
summary += line
print(summary)
if is_in_ci():
write_github_step_summary(f"## {test_name}\n{summary}")
if failed_models:
print("Some models failed the evaluation.")
raise AssertionError("\n".join(failed_models))
# Bench knobs for bench_one_batch_server (override by env)
def _parse_int_list_env(name: str, default_val: str):
val = os.environ.get(name, default_val)
return [int(x) for x in val.split(",") if x]
# Return filenames
def find_traces_under_path(path: str) -> List[str]:
results = []
for _, dirs, files in os.walk(path):
for file in files:
if file.endswith(".trace.json.gz"):
results.append(f"{file}")
return results
def write_results_to_json(model, metrics, mode="a"):
result = {
"timestamp": datetime.now().isoformat(),
"model": model,
"metrics": metrics,
"score": metrics["score"],
}
if "latency" in metrics:
result["latency"] = (metrics.get("latency"),)
existing_results = []
if mode == "a" and os.path.exists("results.json"):
try:
with open("results.json", "r") as f:
existing_results = json.load(f)
except json.JSONDecodeError:
existing_results = []
if isinstance(existing_results, list):
existing_results.append(result)
else:
existing_results = [result]
with open("results.json", "w") as f:
json.dump(existing_results, f, indent=2)
"""Common utilities""" """Common utilities"""
import functools
import importlib import importlib
import inspect
import json import json
import logging import logging
import os import os
import random import random
import socket import socket
import ssl
import subprocess import subprocess
import sys import sys
import time import time
...@@ -22,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union ...@@ -22,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np import numpy as np
import pybase64 import pybase64
import requests import requests
import triton
from IPython.display import HTML, display from IPython.display import HTML, display
from pydantic import BaseModel from pydantic import BaseModel
from tqdm import tqdm from tqdm import tqdm
...@@ -156,15 +158,7 @@ def http_request( ...@@ -156,15 +158,7 @@ def http_request(
data = bytes(dumps(json), encoding="utf-8") data = bytes(dumps(json), encoding="utf-8")
try: try:
if sys.version_info >= (3, 13): resp = urllib.request.urlopen(req, data=data, cafile=verify)
# Python 3.13+: Use SSL context (cafile removed)
if verify and isinstance(verify, str):
context = ssl.create_default_context(cafile=verify)
else:
context = ssl.create_default_context()
resp = urllib.request.urlopen(req, data=data, context=context)
else:
resp = urllib.request.urlopen(req, data=data, cafile=verify)
return HttpResponse(resp) return HttpResponse(resp)
except urllib.error.HTTPError as e: except urllib.error.HTTPError as e:
return HttpResponse(e) return HttpResponse(e)
...@@ -549,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any: ...@@ -549,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
module_name, obj_name = qualname.rsplit(".", 1) module_name, obj_name = qualname.rsplit(".", 1)
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
return getattr(module, obj_name) return getattr(module, obj_name)
class CachedKernel:
"""
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
This wrapper caches compiled Triton kernels based on keys extracted by a
user-provided key function to avoid redundant compilations.
"""
def __init__(self, fn, key_fn=None):
self.fn = fn
assert isinstance(fn, triton.runtime.jit.JITFunction)
original_fn = fn.fn
self.signature = inspect.signature(original_fn)
self.param_names = tuple(self.signature.parameters.keys())
self.num_args = len(self.param_names)
# Check that no parameters have default values
for name, param in self.signature.parameters.items():
assert (
param.default is inspect.Parameter.empty
), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
functools.update_wrapper(self, original_fn)
self.kernel_cache = {}
# Store the key function
self.key_fn = key_fn
def __getitem__(self, grid):
"""
Index with grid to get a launcher function.
Returns a launcher that will handle caching based on the key function.
"""
assert (
isinstance(grid, tuple) and len(grid) <= 3
), "Grid must be a tuple with at most 3 dimensions."
# Normalize grid once
if len(grid) < 3:
grid = grid + (1,) * (3 - len(grid))
def launcher(*args, **kwargs):
cache_key = self.key_fn(args, kwargs)
cached_kernel = self.kernel_cache.get(cache_key)
if cached_kernel is None:
# First time: compile and cache the kernel
cached_kernel = self.fn[grid](*args, **kwargs)
self.kernel_cache[cache_key] = cached_kernel
return cached_kernel
else:
# Use cached kernel
all_args = self._build_args(args, kwargs)
cached_kernel[grid](*all_args)
return cached_kernel
return launcher
def _build_args(self, args, kwargs):
"""
Build the complete argument list for kernel invocation.
"""
complete_args = list(args)
for i in range(len(args), self.num_args):
name = self.param_names[i]
value = kwargs.get(name, inspect.Parameter.empty)
if value is not inspect.Parameter.empty:
complete_args.append(value)
else:
raise ValueError(f"Missing argument: {name}")
return complete_args
def cached_triton_kernel(key_fn=None):
"""
Decorator that enables key-based caching for Triton kernels using a key function.
It essentially bypasses Triton's built-in caching mechanism, allowing users to
define their own caching strategy based on kernel parameters. This helps reduce
the heavy overheads of Triton kernel launch when the kernel specialization dispatch
is simple.
Usage:
@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
@triton.jit
def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
...
# Invoke normally
my_kernel[grid](x, y, BLOCK_SIZE=1024)
Args:
key_fn: A function that takes (args, kwargs) and returns the cache key(s).
The key can be a single value or a tuple of values.
Returns:
A decorator that wraps the kernel with caching functionality.
Note: Kernels with default parameter values are not supported and will raise an assertion error.
"""
def decorator(fn):
return CachedKernel(fn, key_fn)
return decorator
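A hedged usage sketch of the wrapper added above; the kernel, tensor sizes, and import path are illustrative assumptions, not part of this commit.

# Hypothetical example: cache a trivial Triton kernel keyed on BLOCK_SIZE only.
# The import path below is an assumption about where the code above lives.
import torch
import triton
import triton.language as tl

from sglang.utils import cached_triton_kernel  # assumed location of the decorator

@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get("BLOCK_SIZE", 1024))
@triton.jit
def _double_kernel(x_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)

x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)

# The first launch compiles and stores the kernel under key 1024; the second
# launch hits the cache and bypasses Triton's own dispatch machinery.
_double_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)
_double_kernel[grid](x, out, x.numel(), BLOCK_SIZE=1024)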