change / sglang · Commits · 852a49c5

Commit 852a49c5 · authored Sep 30, 2025 by maxiao

    adapt to dsv32 on dcu

Parent: 8f7453e3

Showing 19 changed files with 226 additions and 870 deletions (+226 / -870)
python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp                 +19   -18
python/sglang/srt/speculative/cpp_lookahead/lookahead.h                    +8    -8
python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py            +10    -8
python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp   +12   -12
python/sglang/srt/speculative/cpp_lookahead/param.h                        +2    -2
python/sglang/srt/speculative/cpp_lookahead/queue.h                        +0    -0
python/sglang/srt/speculative/eagle_utils.py                              +13   -15
python/sglang/srt/speculative/lookahead_utils.py                           +3    -3
python/sglang/srt/speculative/lookahead_worker.py                         +20   -20
python/sglang/srt/speculative/spec_info.py                                 +4    -4
python/sglang/srt/two_batch_overlap.py                                     +1    -1
python/sglang/srt/utils.py                                                +14  -171
python/sglang/test/run_eval.py                                             +0    -7
python/sglang/test/simple_eval_common.py                                   +1    -1
python/sglang/test/simple_eval_mmmu_vlm.py                                 +0  -441
python/sglang/test/test_block_fp8.py                                       +2    -2
python/sglang/test/test_deterministic.py                                   +1    -1
python/sglang/test/test_utils.py                                           +1  -146
python/sglang/utils.py                                                   +115   -10
python/sglang/srt/speculative/cpp_ngram/ngram.cpp → python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp

-#include "ngram.h"
+#include "lookahead.h"
 #include <limits>
 #include <vector>

-namespace ngram {
+namespace lookahead {

 struct Node {
   std::unordered_map<int32_t, int32_t> next;
 };

-Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
-  Ngram::Result info;
+Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
+  Lookahead::Result info;
   std::vector<int32_t> prevs;
   info.token.reserve(draft_token_num);
   prevs.reserve(draft_token_num);

@@ -50,7 +50,7 @@ Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>&
   return info;
 }

-Ngram::Ngram(size_t capacity, const Param& param) {
+Lookahead::Lookahead(size_t capacity, const Param& param) {
   param_ = param;
   nodes_.resize(capacity);
   for (auto& node : nodes_) {

@@ -116,16 +116,17 @@ Ngram::Ngram(size_t capacity, const Param& param) {
   }
   quit_flag_ = false;
-  insert_worker_ = std::thread(&Ngram::insert, this);
+  insert_worker_ = std::thread(&Lookahead::insert, this);
 }

-Ngram::~Ngram() {
+Lookahead::~Lookahead() {
   quit_flag_ = true;
   insert_queue_.close();
   insert_worker_.join();
 }

-std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
+std::vector<std::pair<TrieNode*, int32_t>> Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
   auto draft_token_num = param_.get_draft_token_num(batch_size);
   auto min_match_window_size = param_.get_min_match_window_size(batch_size);
   auto max_match_window_size = param_.max_match_window_size;

@@ -153,7 +154,7 @@ std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_
   return result;
 }

-void Ngram::squeeze(size_t count) {
+void Lookahead::squeeze(size_t count) {
   if (!(node_pool_.size() >= free_node_count_ + count)) {
     throw std::runtime_error(
         "Insufficient node size to release required nodes. "

@@ -176,13 +177,13 @@ void Ngram::squeeze(size_t count) {
   }
 }

-void Ngram::synchronize() const {
+void Lookahead::synchronize() const {
   while (!insert_queue_.empty()) {
     std::this_thread::sleep_for(std::chrono::microseconds(10));
   }
 }

-void Ngram::insert() {
+void Lookahead::insert() {
   while (!quit_flag_) {
     std::vector<int32_t> data;
     if (!insert_queue_.dequeue(data)) {

@@ -238,13 +239,13 @@ void Ngram::insert() {
   }
 }

-void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
+void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
   for (auto&& token : tokens) {
     insert_queue_.enqueue(std::move(token));
   }
 }

-Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
+Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
   std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
   double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /

@@ -283,7 +284,7 @@ Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_s
   return fillResult(tokens.back(), draft_token_num + 1, tree, root);
 }

-Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
+Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
   std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
   auto draft_token_num = param_.get_draft_token_num(batch_size);

@@ -345,10 +346,10 @@ Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_
   return fillResult(tokens.back(), draft_token_num + 1, tree, root);
 }

-Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
+Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
   std::unique_lock<std::mutex> lock(mutex_);
   Result merged_result;
-  auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
+  auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
   for (const auto& tks : tokens) {
     Result res = (this->*match_func)(tks, tokens.size());
     merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());

@@ -357,7 +358,7 @@ Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens)
   return merged_result;
 }

-void Ngram::Result::truncate(size_t n) {
+void Lookahead::Result::truncate(size_t n) {
   if (n < token.size()) {
     int full_n = token.size();
     for (int i = 1; i < n; ++i) {

@@ -368,4 +369,4 @@ void Ngram::Result::truncate(size_t n) {
   }
 }
-}  // namespace ngram
+}  // namespace lookahead
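The renamed class keeps the same matching flow: match() walks the recent tokens against the trie, matchBFS/matchProb expand a draft tree, and batchMatch picks one of the two via param_.match_type and concatenates the per-request results under a single lock. A rough Python sketch of that dispatch-and-merge shape (the Result class and match functions here are simplified stand-ins for the C++ members above, not a binding API):

from dataclasses import dataclass, field
from threading import Lock
from typing import Callable, List


@dataclass
class Result:
    # Flat draft-token buffer; per-request results are appended back to back,
    # mirroring merged_result.token.insert(...) in Lookahead::batchMatch.
    token: List[int] = field(default_factory=list)


class LookaheadSketch:
    def __init__(self, match_type: str = "BFS"):
        self.match_type = match_type
        self._lock = Lock()

    def match_bfs(self, tokens: List[int], batch_size: int) -> Result:
        # Placeholder for the breadth-first tree expansion in the C++ code.
        return Result(token=tokens[-1:])

    def match_prob(self, tokens: List[int], batch_size: int) -> Result:
        # Placeholder for the frequency-guided expansion in matchProb.
        return Result(token=tokens[-1:])

    def batch_match(self, batch: List[List[int]]) -> Result:
        # Same shape as Lookahead::batchMatch: choose the matcher once,
        # then run every request under one lock and merge the tokens.
        match_func: Callable = (
            self.match_bfs if self.match_type == "BFS" else self.match_prob
        )
        merged = Result()
        with self._lock:
            for tokens in batch:
                res = match_func(tokens, len(batch))
                merged.token.extend(res.token)
        return merged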
python/sglang/srt/speculative/cpp_ngram/ngram.h → python/sglang/srt/speculative/cpp_lookahead/lookahead.h

@@ -15,7 +15,7 @@
 #include "param.h"
 #include "queue.h"

-namespace ngram {
+namespace lookahead {

 struct TrieNode {
   std::unordered_map<int32_t, TrieNode*> child;

@@ -34,7 +34,7 @@ struct TrieNode {
   std::multiset<TrieNode*, CompareByFreq> sorted_children;
 };

-class Ngram {
+class Lookahead {
   std::vector<TrieNode> nodes_;
   std::vector<TrieNode*> node_pool_;
   size_t free_node_count_;

@@ -61,12 +61,12 @@ class Ngram {
   std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;

 public:
-  Ngram(size_t capacity, const Param& param);
-  Ngram() = default;
-  ~Ngram();
+  Lookahead(size_t capacity, const Param& param);
+  Lookahead() = default;
+  ~Lookahead();

-  static Ngram& instance() {
-    static Ngram instance;
+  static Lookahead& instance() {
+    static Lookahead instance;
     return instance;
   }

@@ -107,4 +107,4 @@ class Ngram {
   void insert();
 };
-}  // namespace ngram
+}  // namespace lookahead
python/sglang/srt/speculative/cpp_ngram/ngram_cache.py → python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py

 # -*- coding: utf-8 -*-
+# from sglang.op.lookahead import Lookahead, Param
 import logging
 import os
 from typing import List, Tuple

@@ -10,17 +12,17 @@ from torch.utils.cpp_extension import load
 logger = logging.getLogger(__name__)

 _abs_path = os.path.dirname(os.path.abspath(__file__))
-ngram_cache_cpp = load(
-    name="ngram_cache_cpp",
+lookahead_cache_cpp = load(
+    name="lookahead_cache_cpp",
     sources=[
-        f"{_abs_path}/ngram_cache_binding.cpp",
-        f"{_abs_path}/ngram.cpp",
+        f"{_abs_path}/lookahead_cache_binding.cpp",
+        f"{_abs_path}/lookahead.cpp",
     ],
     extra_cflags=["-O3", "-std=c++20"],
 )


-class NgramCache:
+class LookaheadCache:
     def __init__(
         self,
         branch_length=18,

@@ -32,7 +34,7 @@ class NgramCache:
         match_type="BFS",
         capacity=1000000,
     ):
-        param = ngram_cache_cpp.Param()
+        param = lookahead_cache_cpp.Param()
         param.branch_length = branch_length
         param.min_match_window_size = min_match_window_size
         param.max_match_window_size = max_match_window_size

@@ -40,7 +42,7 @@ class NgramCache:
         param.max_bfs_breadth = max_bfs_breadth
         param.draft_token_num = draft_token_num
         param.match_type = match_type
-        self.cache = ngram_cache_cpp.Ngram(capacity, param)
+        self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
         self.default_mask = np.ones((1, 1), dtype=np.int64)
         self.draft_token_num = draft_token_num

@@ -129,7 +131,7 @@ if __name__ == "__main__":
         [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
         [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
     ]
-    cache = NgramCache(branch_length=12, draft_token_num=8)
+    cache = LookaheadCache(branch_length=12, draft_token_num=8)
     cache.batch_put(token_ids)
     cache.synchronize()
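After the rename, the Python wrapper still JIT-compiles the C++ sources with torch.utils.cpp_extension.load, just under the lookahead names, so the first import triggers a local build. A minimal usage sketch, assuming the extension builds and that batch_get returns (drafts, mask) as it is consumed by the worker later in this commit:

from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache

# Constructor defaults mirror the signature shown in the diff above.
cache = LookaheadCache(branch_length=12, draft_token_num=8)

# Insert token sequences asynchronously, then wait for the background
# insert thread to drain its queue before matching.
cache.batch_put([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
cache.synchronize()

# batch_get is what the worker calls to fetch draft tokens and the tree mask.
req_drafts, mask = cache.batch_get([[1, 2, 3]])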
python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp → python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp

 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
-#include "ngram.h"
+#include "lookahead.h"

-PYBIND11_MODULE(ngram_cache_cpp, m) {
-  using namespace ngram;
+PYBIND11_MODULE(lookahead_cache_cpp, m) {
+  using namespace lookahead;
   namespace py = pybind11;
   m.doc() = "";

-  py::class_<Ngram>(m, "Ngram")
+  py::class_<Lookahead>(m, "Lookahead")
       .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
-      .def("asyncInsert", &Ngram::asyncInsert, "")
-      .def("batchMatch", &Ngram::batchMatch, "")
-      .def("reset", &Ngram::reset, "")
-      .def("synchronize", &Ngram::synchronize, "");
+      .def("asyncInsert", &Lookahead::asyncInsert, "")
+      .def("batchMatch", &Lookahead::batchMatch, "")
+      .def("reset", &Lookahead::reset, "")
+      .def("synchronize", &Lookahead::synchronize, "");

   py::class_<Param>(m, "Param")
       .def(py::init<>())

@@ -35,9 +35,9 @@ PYBIND11_MODULE(ngram_cache_cpp, m) {
       .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
       .def("detail", &Param::detail, "");

-  py::class_<Ngram::Result>(m, "Result")
+  py::class_<Lookahead::Result>(m, "Result")
       .def(py::init<>())
-      .def_readwrite("token", &Ngram::Result::token)
-      .def_readwrite("mask", &Ngram::Result::mask)
-      .def("truncate", &Ngram::Result::truncate);
+      .def_readwrite("token", &Lookahead::Result::token)
+      .def_readwrite("mask", &Lookahead::Result::mask)
+      .def("truncate", &Lookahead::Result::truncate);
 }
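The pybind11 module registers the same surface under the lookahead names. A small sketch of driving the raw binding directly; the attribute and method names come from the .def calls above, the parameter values are illustrative, and the module object is normally the one produced by load(...) in lookahead_cache.py rather than something you would use standalone:

# lookahead_cache_cpp is the module object built by load(...) in lookahead_cache.py.
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import lookahead_cache_cpp

param = lookahead_cache_cpp.Param()
param.branch_length = 18
param.draft_token_num = 8
param.match_type = "BFS"

cache = lookahead_cache_cpp.Lookahead(capacity=1000000, param=param)
cache.asyncInsert([[1, 2, 3, 4, 5]])
cache.synchronize()

result = cache.batchMatch([[1, 2, 3]])
print(result.token, result.mask)
result.truncate(8)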
python/sglang/srt/speculative/cpp_ngram/param.h → python/sglang/srt/speculative/cpp_lookahead/param.h

@@ -9,7 +9,7 @@
 #include <string>
 #include <vector>

-namespace ngram {
+namespace lookahead {

 struct Param {
   bool enable;

@@ -122,4 +122,4 @@ struct Param {
   }
 };
-}  // namespace ngram
+}  // namespace lookahead
python/sglang/srt/speculative/cpp_ngram/queue.h → python/sglang/srt/speculative/cpp_lookahead/queue.h

File moved
python/sglang/srt/speculative/eagle_utils.py

@@ -13,7 +13,6 @@ import triton
 import triton.language as tl

 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
-from sglang.srt.environ import envs
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import apply_custom_logit_processor

@@ -24,7 +23,7 @@ from sglang.srt.managers.schedule_batch import (
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

 if is_cuda():

@@ -43,8 +42,8 @@ logger = logging.getLogger(__name__)
 # Simulate acceptance length for benchmarking purposes
-SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
-SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
+SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
+SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")

 TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly

@@ -501,12 +500,13 @@ class EagleVerifyInput:
             deterministic=True,
         )

-        if SIMULATE_ACC_LEN > 0.0:
+        if SIMULATE_ACC_LEN:
             # Do simulation
             accept_index = _generate_simulated_accept_index(
                 accept_index=accept_index,
                 predict=predict,  # mutable
                 accept_length=accept_length,  # mutable
+                simulate_acc_len=SIMULATE_ACC_LEN,
                 bs=bs,
                 spec_steps=self.spec_steps,
             )

@@ -1131,16 +1131,14 @@ def _generate_simulated_accept_index(
     accept_index,
     predict,
     accept_length,
+    simulate_acc_len,
     bs,
     spec_steps,
-    simulate_acc_len: float = SIMULATE_ACC_LEN,
-    simulate_acc_method: str = SIMULATE_ACC_METHOD,
 ):
-    assert simulate_acc_len > 0.0
-    if simulate_acc_method == "multinomial":
+    simulate_acc_len_float = float(simulate_acc_len)
+    if SIMULATE_ACC_METHOD == "multinomial":
         simulated_values = torch.normal(
-            mean=simulate_acc_len,
+            mean=simulate_acc_len_float,
             std=1.0,
             size=(1,),
             device="cpu",

@@ -1148,19 +1146,19 @@ def _generate_simulated_accept_index(
         # clamp simulated values to be between 1 and self.spec_steps
         simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
         simulate_acc_len = int(simulated_values.round().item())
-    elif simulate_acc_method == "match-expected":
+    elif SIMULATE_ACC_METHOD == "match-expected":
         # multinomial sampling does not match the expected length
         # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
         # match the expected length, One caveat is that this will only sample
         # either round down or round up of the expected length
-        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
-        lower = int(simulate_acc_len // 1)
+        simulate_acc_len_float = max(1.0, min(spec_steps + 1, simulate_acc_len_float))
+        lower = int(simulate_acc_len_float // 1)
         upper = lower + 1 if lower < spec_steps + 1 else lower
         if lower == upper:
             simulate_acc_len = lower
         else:
-            weight_upper = simulate_acc_len - lower
+            weight_upper = simulate_acc_len_float - lower
             weight_lower = 1.0 - weight_upper
             probs = torch.tensor([weight_lower, weight_upper], device="cpu")
             sampled_index = torch.multinomial(probs, num_samples=1)
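These hunks swap the envs-based accessors back to plain os.environ lookups, so SIMULATE_ACC_LEN is now a string (or None), the gate becomes a simple truthiness check, and the value is passed explicitly and converted to float inside the helper. A condensed sketch of the post-change control flow, simplified to the "multinomial" branch only (the real helper also mutates accept_index, predict, and accept_length in place):

import os

import torch

SIMULATE_ACC_LEN = os.environ.get("SIMULATE_ACC_LEN")
SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")


def simulated_accept_len(simulate_acc_len, spec_steps: int) -> int:
    # Mirrors the "multinomial" branch of _generate_simulated_accept_index.
    simulate_acc_len_float = float(simulate_acc_len)
    values = torch.normal(
        mean=simulate_acc_len_float, std=1.0, size=(1,), device="cpu"
    )
    # clamp simulated values to be between 1 and spec_steps + 1
    values = torch.clamp(values, min=1.0, max=spec_steps + 1)
    return int(values.round().item())


if SIMULATE_ACC_LEN:  # unset -> None -> simulation disabled
    print(simulated_accept_len(SIMULATE_ACC_LEN, spec_steps=4))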
python/sglang/srt/speculative/ngram_utils.py → python/sglang/srt/speculative/lookahead_utils.py

@@ -42,7 +42,7 @@ elif is_hip():
 @dataclass
-class NgramVerifyInput:
+class LookaheadVerifyInput:
     def __init__(
         self,
         draft_token: torch.Tensor,

@@ -405,8 +405,8 @@ class NgramVerifyInput:
         return logits_output, self.verified_id, self.accept_length.sum().item()

-    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+    def filter_batch(self, new_indices: torch.Tensor):
         pass

-    def merge_batch(self, spec_info: NgramVerifyInput):
+    def merge_batch(self, spec_info: LookaheadVerifyInput):
         pass
python/sglang/srt/speculative/ngram_worker.py → python/sglang/srt/speculative/lookahead_worker.py

@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
-from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import broadcast_pyobj

@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
 USE_FULL_MASK = True


-class NGRAMWorker:
+class LOOKAHEADWorker:
     def __init__(
         self,
         server_args: ServerArgs,

@@ -38,9 +38,9 @@ class NGRAMWorker:
         self.tp_rank = tp_rank
         self.page_size = server_args.page_size
         self.draft_token_num: int = server_args.speculative_num_draft_tokens
-        self.branch_length: int = server_args.speculative_ngram_branch_length
+        self.branch_length: int = server_args.speculative_lookahead_branch_length
         self.max_match_window_size: int = (
-            server_args.speculative_ngram_max_match_window_size
+            server_args.speculative_lookahead_max_match_window_size
         )
         self.max_batch_size = target_worker.max_running_requests

@@ -48,18 +48,18 @@ class NGRAMWorker:
         self._init_preallocated_tensors()

-        self.ngram_cache = NgramCache(
-            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
-            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
-            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
-            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
-            capacity=server_args.speculative_ngram_capacity,
-            branch_length=server_args.speculative_ngram_branch_length,
+        self.lookahead_cache = LookaheadCache(
+            min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
+            max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
+            min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
+            max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
+            capacity=server_args.speculative_lookahead_capacity,
+            branch_length=server_args.speculative_lookahead_branch_length,
             draft_token_num=server_args.speculative_num_draft_tokens,
         )

     def clear_cache_pool(self):
-        self.ngram_cache.reset()
+        self.lookahead_cache.reset()

     def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
         seq2_len = len(seq2)

@@ -124,14 +124,14 @@ class NGRAMWorker:
     ) -> tuple[np.ndarray, np.ndarray]:
         bs = batch.batch_size()
-        self.ngram_cache.synchronize()
+        self.lookahead_cache.synchronize()
         batch_tokens = []
         for req in batch.reqs:
             check_token = self._efficient_concat_last_n(
                 req.origin_input_ids, req.output_ids, self.max_match_window_size
             )
             batch_tokens.append(check_token)
-        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
+        req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
         total_draft_token_num = len(req_drafts)

         # Check if speculative decoding is needed; here we always enforce it

@@ -184,9 +184,9 @@ class NGRAMWorker:
             tree_mask.append(req_mask.flatten())

         tree_mask = torch.cat(tree_mask, dim=0)
-        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
+        batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
         batch.forward_mode = ForwardMode.TARGET_VERIFY
-        batch.spec_info = NgramVerifyInput(
+        batch.spec_info = LookaheadVerifyInput(
             draft_tokens,
             tree_mask,
             positions,

@@ -197,7 +197,7 @@ class NGRAMWorker:
         )
         batch.spec_info.prepare_for_verify(batch, self.page_size)

-    def _update_ngram_cache(self, batch: ScheduleBatch):
+    def _update_lookahead_cache(self, batch: ScheduleBatch):
         batch_tokens = []
         for req in batch.reqs:
             # FIXME: Whether to insert 'extend' into the cache or not, after testing,

@@ -209,7 +209,7 @@ class NGRAMWorker:
                 req.origin_input_ids, req.output_ids, self.branch_length
             )
             batch_tokens.append(put_ids)
-        self.ngram_cache.batch_put(batch_tokens)
+        self.lookahead_cache.batch_put(batch_tokens)

     def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         self._prepare_for_speculative_decoding(batch)

@@ -227,7 +227,7 @@ class NGRAMWorker:
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                 batch, logits_output, self.page_size
             )
-            self._update_ngram_cache(batch)
+            self._update_lookahead_cache(batch)
             batch.forward_mode = ForwardMode.DECODE
         else:
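The worker now reads the speculative_lookahead_* server arguments and forwards them into LookaheadCache. A compact sketch of that wiring, using a stand-in namespace with hypothetical values instead of a real ServerArgs instance (the kwarg names follow the constructor call in the diff above):

from types import SimpleNamespace

from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache

# Hypothetical values for illustration only.
server_args = SimpleNamespace(
    speculative_lookahead_min_match_window_size=1,
    speculative_lookahead_max_match_window_size=12,
    speculative_lookahead_min_bfs_breadth=1,
    speculative_lookahead_max_bfs_breadth=8,
    speculative_lookahead_capacity=1000000,
    speculative_lookahead_branch_length=18,
    speculative_num_draft_tokens=8,
)

lookahead_cache = LookaheadCache(
    min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
    max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
    min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
    max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
    capacity=server_args.speculative_lookahead_capacity,
    branch_length=server_args.speculative_lookahead_branch_length,
    draft_token_num=server_args.speculative_num_draft_tokens,
)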
python/sglang/srt/speculative/spec_info.py

@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
     EAGLE = auto()
     EAGLE3 = auto()
     STANDALONE = auto()
-    NGRAM = auto()
+    LOOKAHEAD = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
     def is_standalone(self):
         return self == SpeculativeAlgorithm.STANDALONE

-    def is_ngram(self):
-        return self == SpeculativeAlgorithm.NGRAM
+    def is_lookahead(self):
+        return self == SpeculativeAlgorithm.LOOKAHEAD

     @staticmethod
     def from_string(name: str):

@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
             "STANDALONE": SpeculativeAlgorithm.STANDALONE,
-            "NGRAM": SpeculativeAlgorithm.NGRAM,
+            "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
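With the enum member renamed, callers resolve the algorithm from its string form and use the renamed predicate. A short sketch; the string mapping and is_lookahead() come directly from the diff above, while behaviour for names outside the mapping is not shown here:

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("LOOKAHEAD")
assert algo == SpeculativeAlgorithm.LOOKAHEAD
assert algo.is_lookahead()
assert not algo.is_none()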
python/sglang/srt/two_batch_overlap.py

@@ -31,7 +31,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
 from sglang.srt.operations_strategy import OperationsStrategy
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
+from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
python/sglang/srt/utils.py

@@ -22,7 +22,6 @@ import ctypes
 import dataclasses
 import functools
 import importlib
-import inspect
 import io
 import ipaddress
 import itertools

@@ -195,7 +194,7 @@ _warned_bool_env_var_keys = set()
 def get_bool_env_var(name: str, default: str = "false") -> bool:
-    # FIXME: move your environment variable to sglang.srt.environ
+    # FIXME: move your environment variable to sglang.environ
     value = os.getenv(name, default)
     value = value.lower()

@@ -213,7 +212,7 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
 def get_int_env_var(name: str, default: int = 0) -> int:
-    # FIXME: move your environment variable to sglang.srt.environ
+    # FIXME: move your environment variable to sglang.environ
     value = os.getenv(name)
     if value is None or not value.strip():
         return default

@@ -471,7 +470,7 @@ def is_pin_memory_available() -> bool:
 class LayerFn(Protocol):
-    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+    def __call__(self, idx: int, prefix: str) -> torch.nn.Module: ...


 def make_layers(

@@ -482,7 +481,7 @@ def make_layers(
     prefix: str = "",
     return_tuple: bool = False,
     offloader_kwargs: Dict[str, Any] = {},
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Tuple[torch.nn.Module, int, int]:
     """Make a list of layers with the given layer function"""
     # circula imports
     from sglang.srt.distributed import get_pp_indices

@@ -518,50 +517,6 @@ def make_layers(
     return modules, start_layer, end_layer


-cmo_stream = None
-
-
-def get_cmo_stream():
-    """
-    Cache Management Operation(CMO).
-    Launch a new stream to prefetch the weight of matmul when running other
-    AIV or communication kernels, aiming to overlap the memory access time.
-    """
-    global cmo_stream
-    if cmo_stream is None:
-        cmo_stream = torch.get_device_module().Stream()
-    return cmo_stream
-
-
-def prepare_weight_cache(handle, cache):
-    import torch_npu
-
-    NPU_PREFETCH_MAX_SIZE_BYTES = (
-        1000000000  # 1GB, a large value to prefetch entire weight
-    )
-    stream = get_cmo_stream()
-    stream.wait_stream(torch.npu.current_stream())
-    with torch.npu.stream(stream):
-        if isinstance(cache, list):
-            for weight in cache:
-                torch_npu.npu_prefetch(
-                    weight,
-                    handle,
-                    NPU_PREFETCH_MAX_SIZE_BYTES,
-                )
-        else:
-            torch_npu.npu_prefetch(
-                cache,
-                handle,
-                NPU_PREFETCH_MAX_SIZE_BYTES,
-            )
-
-
-def wait_cmo_stream():
-    cur_stream = torch.get_device_module().current_stream()
-    cur_stream.wait_stream(get_cmo_stream())
-
-
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)

@@ -2054,6 +2009,13 @@ def set_uvicorn_logging_configs():
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"


+def get_ip() -> Optional[str]:
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+    return None
+
+
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:

@@ -2393,10 +2355,8 @@ def get_local_ip_auto(fallback: str = None) -> str:
     2. Network interface enumeration via get_local_ip_by_nic()
     3. Remote connection method via get_local_ip_by_remote()
     """
-    # Try environment variable
-    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
-    if host_ip:
-        return host_ip
+    if ip := get_ip():
+        return ip
     logger.debug("get_ip failed")
     # Fallback
     if ip := get_local_ip_by_nic():

@@ -2460,7 +2420,7 @@ class BumpAllocator:
 def log_info_on_rank0(logger, msg):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

-    if torch.distributed.is_initialized() and get_tensor_model_parallel_rank() == 0:
+    if get_tensor_model_parallel_rank() == 0:
         logger.info(msg)

@@ -3220,120 +3180,3 @@ def get_extend_input_len_swa_limit(
     # and we can only free out-of-sliding-window kv indices after each prefill.
     # 3. page_size is because we want to have 1 token extra for generated tokens.
     return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
-
-
-class CachedKernel:
-    """
-    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
-
-    This wrapper caches compiled Triton kernels based on keys extracted by a
-    user-provided key function to avoid redundant compilations.
-    """
-
-    def __init__(self, fn, key_fn=None):
-        self.fn = fn
-        assert isinstance(fn, triton.runtime.jit.JITFunction)
-        original_fn = fn.fn
-        self.signature = inspect.signature(original_fn)
-        self.param_names = tuple(self.signature.parameters.keys())
-        self.num_args = len(self.param_names)
-
-        # Check that no parameters have default values
-        for name, param in self.signature.parameters.items():
-            assert (
-                param.default is inspect.Parameter.empty
-            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
-
-        functools.update_wrapper(self, original_fn)
-        self.kernel_cache = {}
-        # Store the key function
-        self.key_fn = key_fn
-
-    def __getitem__(self, grid):
-        """
-        Index with grid to get a launcher function.
-        Returns a launcher that will handle caching based on the key function.
-        """
-        assert (
-            isinstance(grid, tuple) and len(grid) <= 3
-        ), "Grid must be a tuple with at most 3 dimensions."
-
-        # Normalize grid once
-        if len(grid) < 3:
-            grid = grid + (1,) * (3 - len(grid))
-
-        def launcher(*args, **kwargs):
-            cache_key = self.key_fn(args, kwargs)
-            cached_kernel = self.kernel_cache.get(cache_key)
-
-            if cached_kernel is None:
-                # First time: compile and cache the kernel
-                cached_kernel = self.fn[grid](*args, **kwargs)
-                self.kernel_cache[cache_key] = cached_kernel
-                return cached_kernel
-            else:
-                # Use cached kernel
-                all_args = self._build_args(args, kwargs)
-                cached_kernel[grid](*all_args)
-                return cached_kernel
-
-        return launcher
-
-    def _build_args(self, args, kwargs):
-        """
-        Build the complete argument list for kernel invocation.
-        """
-        complete_args = list(args)
-        for i in range(len(args), self.num_args):
-            name = self.param_names[i]
-            value = kwargs.get(name, inspect.Parameter.empty)
-            if value is not inspect.Parameter.empty:
-                complete_args.append(value)
-            else:
-                raise ValueError(f"Missing argument: {name}")
-        return complete_args
-
-    def _clear_cache(self):
-        """
-        Clear the kernel cache for testing purposes.
-        """
-        self.kernel_cache.clear()
-
-
-def cached_triton_kernel(key_fn=None):
-    """
-    Decorator that enables key-based caching for Triton kernels using a key function.
-
-    It essentially bypasses Triton's built-in caching mechanism, allowing users to
-    define their own caching strategy based on kernel parameters. This helps reduce
-    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
-    is simple.
-
-    Usage:
-        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
-        @triton.jit
-        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
-            ...
-
-        # Invoke normally
-        my_kernel[grid](x, y, BLOCK_SIZE=1024)
-
-    Args:
-        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
-                The key can be a single value or a tuple of values.
-
-    Returns:
-        A decorator that wraps the kernel with caching functionality.
-
-    Note: Kernels with default parameter values are not supported and will raise an assertion error.
-    """
-
-    def decorator(fn):
-        return CachedKernel(fn, key_fn)
-
-    return decorator
python/sglang/test/run_eval.py

@@ -60,11 +60,6 @@ def run_eval(args):
         from sglang.test.simple_eval_humaneval import HumanEval

         eval_obj = HumanEval(args.num_examples, args.num_threads)
-    elif args.eval_name == "mmmu":
-        # VLM MMMU evaluation with fixed 100 examples by default
-        from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
-
-        eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")

@@ -99,8 +94,6 @@ def run_eval(args):
     print(f"Total latency: {latency:.3f} s")
     print(f"Score: {metrics['score']:.3f}")

-    if getattr(args, "return_latency", False):
-        return metrics, latency
     return metrics
python/sglang/test/simple_eval_common.py

@@ -136,7 +136,7 @@ class ChatCompletionSampler(SamplerBase):
                 self._pack_message("system", self.system_message)
             ] + message_list
         trial = 0
-        while trial < 6:  # 126 seconds in total
+        while True:
             try:
                 response = self.client.chat.completions.create(
                     model=self.model,
python/sglang/test/simple_eval_mmmu_vlm.py
deleted
100644 → 0
View file @
8f7453e3
"""
MMMU evaluation for VLMs using the run_eval simple-evals interface.
"""
from
__future__
import
annotations
import
base64
import
io
from
typing
import
List
,
Optional
,
Tuple
from
datasets
import
concatenate_datasets
,
load_dataset
from
PIL
import
Image
from
sglang.test
import
simple_eval_common
as
common
from
sglang.test.simple_eval_common
import
(
HTML_JINJA
,
Eval
,
EvalResult
,
SamplerBase
,
SingleEvalResult
,
map_with_progress
,
)
class
MMMUVLMEval
(
Eval
):
DOMAIN_CAT2SUB_CAT
=
{
"Art and Design"
:
[
"Art"
,
"Art_Theory"
,
"Design"
,
"Music"
],
"Business"
:
[
"Accounting"
,
"Economics"
,
"Finance"
,
"Manage"
,
"Marketing"
],
"Science"
:
[
"Biology"
,
"Chemistry"
,
"Geography"
,
"Math"
,
"Physics"
],
"Health and Medicine"
:
[
"Basic_Medical_Science"
,
"Clinical_Medicine"
,
"Diagnostics_and_Laboratory_Medicine"
,
"Pharmacy"
,
"Public_Health"
,
],
"Humanities and Social Science"
:
[
"History"
,
"Literature"
,
"Sociology"
,
"Psychology"
,
],
"Tech and Engineering"
:
[
"Agriculture"
,
"Architecture_and_Engineering"
,
"Computer_Science"
,
"Electronics"
,
"Energy_and_Power"
,
"Materials"
,
"Mechanical_Engineering"
,
],
}
def
__init__
(
self
,
num_examples
:
Optional
[
int
]
=
100
,
num_threads
:
int
=
32
,
seed
:
int
=
42
):
"""Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
self
.
num_examples
=
num_examples
self
.
num_threads
=
num_threads
self
.
seed
=
seed
# Prepare samples deterministically across all MMMU subjects (validation split)
self
.
samples
=
self
.
_prepare_mmmu_samples
(
self
.
num_examples
)
@
staticmethod
def
_to_data_uri
(
image
:
Image
.
Image
)
->
str
:
if
image
.
mode
==
"RGBA"
:
image
=
image
.
convert
(
"RGB"
)
buf
=
io
.
BytesIO
()
image
.
save
(
buf
,
format
=
"PNG"
)
b64
=
base64
.
b64encode
(
buf
.
getvalue
()).
decode
(
"utf-8"
)
return
f
"data:image/png;base64,
{
b64
}
"
@
staticmethod
def
_build_mc_mapping
(
options
:
List
[
str
])
->
Tuple
[
dict
,
List
[
str
]]:
index2ans
=
{}
all_choices
=
[]
ch
=
ord
(
"A"
)
for
opt
in
options
:
letter
=
chr
(
ch
)
index2ans
[
letter
]
=
opt
all_choices
.
append
(
letter
)
ch
+=
1
return
index2ans
,
all_choices
def
_prepare_mmmu_samples
(
self
,
k
:
int
)
->
List
[
dict
]:
# Subjects and domains copied from MMMU data_utils to categorize results
subjects
:
List
[
str
]
=
[]
for
subs
in
self
.
DOMAIN_CAT2SUB_CAT
.
values
():
subjects
.
extend
(
subs
)
# Load validation split of each subject
datasets
=
[]
for
subj
in
subjects
:
try
:
d
=
load_dataset
(
"MMMU/MMMU"
,
subj
,
split
=
"validation"
)
# attach subject info via transform
d
=
d
.
add_column
(
"__subject__"
,
[
subj
]
*
len
(
d
))
datasets
.
append
(
d
)
except
Exception
:
continue
if
not
datasets
:
raise
RuntimeError
(
"Failed to load MMMU datasets"
)
merged
=
concatenate_datasets
(
datasets
)
# Deterministic selection: sort by id (fallback to subject+index)
def
_key
(
idx
):
ex
=
merged
[
idx
]
return
str
(
ex
.
get
(
"id"
,
f
"
{
ex
[
'__subject__'
]
}
:
{
idx
}
"
))
order
=
sorted
(
range
(
len
(
merged
)),
key
=
_key
)
picked_indices
=
order
[:
k
]
samples
:
List
[
dict
]
=
[]
for
idx
in
picked_indices
:
ex
=
merged
[
idx
]
subject
=
ex
[
"__subject__"
]
image
=
ex
.
get
(
"image_1"
)
if
image
is
None
or
not
hasattr
(
image
,
"convert"
):
continue
data_uri
=
self
.
_to_data_uri
(
image
)
question
=
ex
.
get
(
"question"
,
""
)
answer
=
ex
.
get
(
"answer"
)
raw_options
=
ex
.
get
(
"options"
)
question_type
=
"open"
index2ans
=
None
all_choices
=
None
options
=
None
if
raw_options
:
try
:
options
=
(
raw_options
if
isinstance
(
raw_options
,
list
)
else
list
(
eval
(
raw_options
))
)
if
isinstance
(
options
,
list
)
and
len
(
options
)
>
0
:
index2ans
,
all_choices
=
self
.
_build_mc_mapping
(
options
)
question_type
=
"multiple-choice"
except
Exception
:
options
=
None
# Build final textual prompt; include choices if MC
prompt_text
=
f
"Question:
{
question
}
\n\n
"
if
options
:
letters
=
[
chr
(
ord
(
"A"
)
+
i
)
for
i
in
range
(
len
(
options
))]
for
letter
,
opt
in
zip
(
letters
,
options
):
prompt_text
+=
f
"
{
letter
}
)
{
opt
}
\n
"
prompt_text
+=
"
\n
Answer: "
samples
.
append
(
{
"id"
:
ex
.
get
(
"id"
,
f
"
{
subject
}
:
{
idx
}
"
),
"final_input_prompt"
:
prompt_text
,
"image_data"
:
data_uri
,
"answer"
:
answer
,
"question_type"
:
question_type
,
"index2ans"
:
index2ans
,
"all_choices"
:
all_choices
,
"category"
:
subject
,
}
)
return
samples
@
staticmethod
def
_split_prompt_for_image
(
prompt
:
str
)
->
tuple
[
str
,
str
]:
"""Split a prompt containing an inline image tag into prefix and suffix.
If no tag is present, treat the whole prompt as prefix and empty suffix.
"""
if
"<"
in
prompt
and
">"
in
prompt
:
prefix
=
prompt
.
split
(
"<"
)[
0
]
suffix
=
prompt
.
split
(
">"
,
1
)[
1
]
return
prefix
,
suffix
return
prompt
,
""
@
staticmethod
def
build_chat_messages_from_prompt
(
prompt
:
str
,
image_data
)
->
List
:
"""Split a prompt containing an inline image tag into prefix and suffix.
If no tag is present, treat the whole prompt as prefix and empty suffix.
"""
# Build a vision+text message for OpenAI-compatible API
prefix
,
suffix
=
MMMUVLMEval
.
_split_prompt_for_image
(
prompt
)
content
:
List
[
dict
]
=
[]
if
prefix
:
content
.
append
({
"type"
:
"text"
,
"text"
:
prefix
})
content
.
append
({
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
image_data
}})
if
suffix
:
content
.
append
({
"type"
:
"text"
,
"text"
:
suffix
})
prompt_messages
=
[{
"role"
:
"user"
,
"content"
:
content
}]
return
prompt_messages
def
__call__
(
self
,
sampler
:
SamplerBase
)
->
EvalResult
:
def
fn
(
sample
:
dict
):
prompt
=
sample
[
"final_input_prompt"
]
image_data
=
sample
[
"image_data"
]
prompt_messages
=
MMMUVLMEval
.
build_chat_messages_from_prompt
(
prompt
,
image_data
)
# Sample
response_text
=
sampler
(
prompt_messages
)
# Parse and score
gold
=
sample
[
"answer"
]
if
(
sample
[
"question_type"
]
==
"multiple-choice"
and
sample
[
"all_choices"
]
and
sample
[
"index2ans"
]
):
pred
=
_parse_multi_choice_response
(
response_text
,
sample
[
"all_choices"
],
sample
[
"index2ans"
]
)
score
=
1.0
if
(
gold
is
not
None
and
pred
==
gold
)
else
0.0
extracted_answer
=
pred
else
:
parsed_list
=
_parse_open_response
(
response_text
)
score
=
(
1.0
if
(
gold
is
not
None
and
_eval_open
(
gold
,
parsed_list
))
else
0.0
)
extracted_answer
=
", "
.
join
(
map
(
str
,
parsed_list
))
html_rendered
=
common
.
jinja_env
.
from_string
(
HTML_JINJA
).
render
(
prompt_messages
=
prompt_messages
,
next_message
=
dict
(
content
=
response_text
,
role
=
"assistant"
),
score
=
score
,
correct_answer
=
gold
,
extracted_answer
=
extracted_answer
,
)
convo
=
prompt_messages
+
[
dict
(
content
=
response_text
,
role
=
"assistant"
)]
return
SingleEvalResult
(
html
=
html_rendered
,
score
=
score
,
metrics
=
{
"__category__"
:
sample
[
"category"
]},
convo
=
convo
,
)
results
=
map_with_progress
(
fn
,
self
.
samples
,
self
.
num_threads
)
# Build category table and overall accuracy
# Gather per-sample correctness and category
per_cat_total
:
dict
[
str
,
int
]
=
{}
per_cat_correct
:
dict
[
str
,
int
]
=
{}
htmls
=
[]
convos
=
[]
scores
:
List
[
float
]
=
[]
for
r
in
results
:
# __category__ stored under metrics
cat
=
r
.
metrics
.
get
(
"__category__"
)
if
r
.
metrics
else
None
if
cat
is
None
:
cat
=
"Unknown"
per_cat_total
[
cat
]
=
per_cat_total
.
get
(
cat
,
0
)
+
1
if
r
.
score
:
per_cat_correct
[
cat
]
=
per_cat_correct
.
get(cat, 0) + 1
            htmls.append(r.html)
            convos.append(r.convo)
            if r.score is not None:
                scores.append(r.score)

        evaluation_result = {}
        for cat, tot in per_cat_total.items():
            corr = per_cat_correct.get(cat, 0)
            acc = (corr / tot) if tot > 0 else 0.0
            evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}

        printable_results = {}
        # Domains first
        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
            acc_sum = 0.0
            num_sum = 0
            for cat in cats:
                if cat in evaluation_result:
                    acc_sum += (
                        evaluation_result[cat]["acc"]
                        * evaluation_result[cat]["num_example"]
                    )
                    num_sum += evaluation_result[cat]["num_example"]
            if num_sum > 0:
                printable_results[f"Overall-{domain}"] = {
                    "num": num_sum,
                    "acc": round(acc_sum / num_sum, 3),
                }
            # add each sub-category row if present
            for cat in cats:
                if cat in evaluation_result:
                    printable_results[cat] = {
                        "num": evaluation_result[cat]["num_example"],
                        "acc": evaluation_result[cat]["acc"],
                    }

        # Overall
        total_num = sum(v["num_example"] for v in evaluation_result.values())
        overall_acc = (
            sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
            / total_num
            if total_num > 0
            else 0.0
        )
        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}

        # Build EvalResult
        return EvalResult(
            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
        )


def _parse_multi_choice_response(
    response: str, all_choices: List[str], index2ans: dict
) -> str:
    # loosely adapted from benchmark mmmu eval
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "

    # Prefer explicit letter with bracket e.g. (A)
    candidates: List[str] = []
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)

    if not candidates:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)

    if not candidates and len(response.split()) > 5:
        # try match by option text
        for idx, ans in index2ans.items():
            if ans and ans.lower() in response.lower():
                candidates.append(idx)

    if not candidates:
        # fallback to first choice
        return all_choices[0]
    if len(candidates) == 1:
        return candidates[0]

    # choose the last occurrence
    starts = []
    for can in candidates:
        pos = response.rfind(f"({can})")
        if pos == -1:
            pos = response.rfind(f" {can} ")
        if pos == -1 and index2ans.get(can):
            pos = response.lower().rfind(index2ans[can].lower())
        starts.append(pos)
    return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]


def _check_is_number(s: str) -> bool:
    try:
        float(s.replace(",", ""))
        return True
    except Exception:
        return False


def _normalize_str(s: str):
    s = s.strip()
    if _check_is_number(s):
        s = s.replace(",", "")
        try:
            v = round(float(s), 2)
            return [v]
        except Exception:
            return [s.lower()]
    return [s.lower()] if len(s) > 1 else [" " + s, s + " "]


def _extract_numbers(s: str) -> List[str]:
    import re as _re

    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
    return (
        _re.findall(pattern_commas, s)
        + _re.findall(pattern_scientific, s)
        + _re.findall(pattern_simple, s)
    )


def _parse_open_response(response: str) -> List[str]:
    import re as _re

    def get_key_subresponses(resp: str) -> List[str]:
        resp = resp.strip().strip(".").lower()
        subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
        indicators = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        keys = []
        for i, s in enumerate(subs):
            cands = [*indicators]
            if i == len(subs) - 1:
                cands.append("=")
            shortest = None
            for ind in cands:
                if ind in s:
                    part = s.split(ind)[-1].strip()
                    if not shortest or len(part) < len(shortest):
                        shortest = part
            if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
                keys.append(shortest)
        return keys or [resp]

    key_resps = get_key_subresponses(response)
    pred_list = key_resps.copy()
    for r in key_resps:
        pred_list.extend(_extract_numbers(r))

    out = []
    for x in pred_list:
        out.extend(_normalize_str(x))
    # dedup
    return list(dict.fromkeys(out))


def _eval_open(gold, preds: List[str]) -> bool:
    if isinstance(gold, list):
        norm_answers = []
        for ans in gold:
            norm_answers.extend(_normalize_str(ans))
    else:
        norm_answers = _normalize_str(gold)
    for p in preds:
        if isinstance(p, str):
            for na in norm_answers:
                if isinstance(na, str) and na in p:
                    return True
        else:
            if p in norm_answers:
                return True
    return False
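Taken together, these helpers implement the grading path: _parse_multi_choice_response picks a letter for multiple-choice items, while open-ended items flow through _parse_open_response and _eval_open. A small illustrative sketch of that flow (the inputs are invented, not from this commit):

response = "Therefore the answer is (B)."
all_choices = ["A", "B", "C", "D"]
index2ans = {"A": "3", "B": "4", "C": "5", "D": "6"}
letter = _parse_multi_choice_response(response, all_choices, index2ans)  # -> "B"

open_response = "So the result is 1,000"
preds = _parse_open_response(open_response)  # -> [1000.0] after number extraction and normalization
hit = _eval_open("1000", preds)              # -> True, the normalized gold value matches a prediction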
python/sglang/test/test_block_fp8.py View file @ 852a49c5
...
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
             w_s,
         )
 
-        from deep_gemm import fp8_m_grouped_gemm_nt_masked
+        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
 
         with torch.inference_mode():
             ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
-            fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
             out = oe[:, :M, :]
 
             self.assertTrue(
...
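The only functional change in this hunk is the deep_gemm symbol name, which differs between deep_gemm builds. A hedged sketch (not part of this commit) of keeping the test importable against either build, using only the two names that appear in the hunk above:

try:
    from deep_gemm import fp8_m_grouped_gemm_nt_masked as grouped_gemm_masked
except ImportError:
    # fall back to the alternative naming seen on the other side of the diff
    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked as grouped_gemm_masked

# then invoked as in the test body: grouped_gemm_masked(lhs, rhs, oe, masked_m, expected_m)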
python/sglang/test/test_deterministic.py View file @ 852a49c5
...
@@ -19,7 +19,7 @@ from sglang.profiler import run_profile
 
 PROMPT_1 = "Tell me about Richard Feynman: "
 PROMPT_2 = "Generate 1000 random numbers. Go directly into it, don't say Sure and don't say here are numbers. Just start with a number."
 
 dirpath = os.path.dirname(__file__)
-with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+with open("python/sglang/test/long_prompt.txt", "r") as f:
     LONG_PROMPT = f.read()
...
python/sglang/test/test_utils.py View file @ 852a49c5
...
@@ -14,12 +14,10 @@ import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from datetime import datetime
 from functools import partial
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Any, Awaitable, Callable, List, Optional, Tuple
-from urllib.parse import quote
 
 import aiohttp
 import numpy as np
...
@@ -82,7 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
 )
 DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
-DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
+DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
 # Other use cases
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
...
@@ -1469,146 +1467,3 @@ def dump_bench_raw_result(
 def _ensure_remove_suffix(text: str, suffix: str):
     assert text.endswith(suffix)
     return text.removesuffix(suffix)
-
-
-class ModelDeploySetup:
-    def __init__(self, model_path: str, extra_args: List[str] = []):
-        self.model_path = model_path
-        if "--enable-multimodal" not in extra_args:
-            extra_args.append("--enable-multimodal")
-        if "--trust-remote-code" not in extra_args:
-            extra_args.append("--trust-remote-code")
-        self.extra_args = extra_args
-
-
-class ModelEvalMetrics:
-    def __init__(self, accuracy: float, eval_time: float):
-        self.accuracy = accuracy
-        self.eval_time = eval_time
-
-
-def extract_trace_link_from_bench_one_batch_server_output(output: str) -> str:
-    match = re.search(r"\[Profile\]\((.*?)\)", output)
-    if match:
-        trace_link = match.group(1)
-        return trace_link
-    return None
-
-
-def parse_models(model_string: str):
-    return [model.strip() for model in model_string.split(",") if model.strip()]
-
-
-def check_evaluation_test_results(
-    results,
-    test_name,
-    model_accuracy_thresholds,
-    model_latency_thresholds=None,
-    model_count=None,
-):
-    """
-    results: list of tuple of (model_path, accuracy, latency)
-    """
-    failed_models = []
-
-    if model_latency_thresholds is not None:
-        summary = " | model | status | score | score_threshold | latency | latency_threshold |\n"
-        summary += "| ----- | ------ | ----- | --------------- | ------- | ----------------- |\n"
-    else:
-        summary = " | model | status | score | score_threshold |\n"
-        summary += "| ----- | ------ | ----- | --------------- |\n"
-
-    results_dict = {res[0]: (res[1], res[2]) for res in results}
-
-    for model, accuracy_threshold in sorted(model_accuracy_thresholds.items()):
-        latency_threshold = (
-            model_latency_thresholds.get(model)
-            if model_latency_thresholds is not None
-            else 1e9
-        )
-
-        if model in results_dict:
-            accuracy, latency = results_dict[model]
-            is_success = accuracy >= accuracy_threshold and latency <= latency_threshold
-            status_emoji = "✅" if is_success else "❌"
-
-            if not is_success:
-                if accuracy < accuracy_threshold:
-                    failed_models.append(
-                        f"\nScore Check Failed: {model}\n"
-                        f"Model {model} score ({accuracy:.4f}) is below threshold ({accuracy_threshold:.4f})"
-                    )
-                if latency > latency_threshold:
-                    failed_models.append(
-                        f"\nLatency Check Failed: {model}\n"
-                        f"Model {model} latency ({latency:.4f}) is above threshold ({latency_threshold:.4f})"
-                    )
-
-            if model_latency_thresholds is not None:
-                line = f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold} | {latency} | {latency_threshold}\n"
-            else:
-                line = (
-                    f"| {model} | {status_emoji} | {accuracy} | {accuracy_threshold}\n"
-                )
-        else:
-            status_emoji = "❌"
-            failed_models.append(f"Model failed to launch or be evaluated: {model}")
-            if model_latency_thresholds is not None:
-                line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold} | N/A | {latency_threshold}\n"
-            else:
-                line = f"| {model} | {status_emoji} | N/A | {accuracy_threshold}\n"
-
-        summary += line
-
-    print(summary)
-
-    if is_in_ci():
-        write_github_step_summary(f"## {test_name}\n{summary}")
-
-    if failed_models:
-        print("Some models failed the evaluation.")
-        raise AssertionError("\n".join(failed_models))
-
-
-# Bench knobs for bench_one_batch_server (override by env)
-def _parse_int_list_env(name: str, default_val: str):
-    val = os.environ.get(name, default_val)
-    return [int(x) for x in val.split(",") if x]
-
-
-# Return filenames
-def find_traces_under_path(path: str) -> List[str]:
-    results = []
-    for _, dirs, files in os.walk(path):
-        for file in files:
-            if file.endswith(".trace.json.gz"):
-                results.append(f"{file}")
-    return results
-
-
-def write_results_to_json(model, metrics, mode="a"):
-    result = {
-        "timestamp": datetime.now().isoformat(),
-        "model": model,
-        "metrics": metrics,
-        "score": metrics["score"],
-    }
-    if "latency" in metrics:
-        result["latency"] = (metrics.get("latency"),)
-
-    existing_results = []
-    if mode == "a" and os.path.exists("results.json"):
-        try:
-            with open("results.json", "r") as f:
-                existing_results = json.load(f)
-        except json.JSONDecodeError:
-            existing_results = []
-
-    if isinstance(existing_results, list):
-        existing_results.append(result)
-    else:
-        existing_results = [result]
-
-    with open("results.json", "w") as f:
-        json.dump(existing_results, f, indent=2)
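For reference, the deleted check_evaluation_test_results consumed results as (model_path, accuracy, latency) tuples checked against per-model thresholds. An illustrative call with invented values (the helper no longer exists after this commit):

results = [
    ("meta-llama/Llama-3.1-8B-Instruct", 0.86, 120.0),  # (model_path, accuracy, latency)
]
check_evaluation_test_results(
    results,
    "nightly-eval",
    model_accuracy_thresholds={"meta-llama/Llama-3.1-8B-Instruct": 0.80},
    model_latency_thresholds={"meta-llama/Llama-3.1-8B-Instruct": 300.0},
)
# prints a markdown summary table and raises AssertionError only when a threshold is violated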
python/sglang/utils.py View file @ 852a49c5
"""Common utilities"""
"""Common utilities"""
import
functools
import
importlib
import
importlib
import
inspect
import
json
import
json
import
logging
import
logging
import
os
import
os
import
random
import
random
import
socket
import
socket
import
ssl
import
subprocess
import
subprocess
import
sys
import
sys
import
time
import
time
...
@@ -22,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
...
@@ -22,6 +23,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union
import
numpy
as
np
import
numpy
as
np
import
pybase64
import
pybase64
import
requests
import
requests
import
triton
from
IPython.display
import
HTML
,
display
from
IPython.display
import
HTML
,
display
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -156,15 +158,7 @@ def http_request(
...
@@ -156,15 +158,7 @@ def http_request(
data
=
bytes
(
dumps
(
json
),
encoding
=
"utf-8"
)
data
=
bytes
(
dumps
(
json
),
encoding
=
"utf-8"
)
try
:
try
:
if
sys
.
version_info
>=
(
3
,
13
):
resp
=
urllib
.
request
.
urlopen
(
req
,
data
=
data
,
cafile
=
verify
)
# Python 3.13+: Use SSL context (cafile removed)
if
verify
and
isinstance
(
verify
,
str
):
context
=
ssl
.
create_default_context
(
cafile
=
verify
)
else
:
context
=
ssl
.
create_default_context
()
resp
=
urllib
.
request
.
urlopen
(
req
,
data
=
data
,
context
=
context
)
else
:
resp
=
urllib
.
request
.
urlopen
(
req
,
data
=
data
,
cafile
=
verify
)
return
HttpResponse
(
resp
)
return
HttpResponse
(
resp
)
except
urllib
.
error
.
HTTPError
as
e
:
except
urllib
.
error
.
HTTPError
as
e
:
return
HttpResponse
(
e
)
return
HttpResponse
(
e
)
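The branch removed in the hunk above existed because urlopen() no longer accepts the cafile argument on Python 3.13+, as its inline comment notes; the SSLContext form is the portable spelling. A standalone sketch of that pattern (the URL and CA-bundle path are placeholders, not from this commit):

import ssl
import urllib.request

context = ssl.create_default_context(cafile="/path/to/ca-bundle.pem")  # hypothetical CA bundle path
resp = urllib.request.urlopen("https://example.com", context=context)
print(resp.status)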
...
@@ -549,3 +543,114 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
     module_name, obj_name = qualname.rsplit(".", 1)
     module = importlib.import_module(module_name)
     return getattr(module, obj_name)
+
+
+class CachedKernel:
+    """
+    Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
+
+    This wrapper caches compiled Triton kernels based on keys extracted by a
+    user-provided key function to avoid redundant compilations.
+    """
+
+    def __init__(self, fn, key_fn=None):
+        self.fn = fn
+        assert isinstance(fn, triton.runtime.jit.JITFunction)
+        original_fn = fn.fn
+
+        self.signature = inspect.signature(original_fn)
+        self.param_names = tuple(self.signature.parameters.keys())
+        self.num_args = len(self.param_names)
+
+        # Check that no parameters have default values
+        for name, param in self.signature.parameters.items():
+            assert (
+                param.default is inspect.Parameter.empty
+            ), f"Parameter '{name}' has a default value. Default parameters are not supported in cached kernels."
+
+        functools.update_wrapper(self, original_fn)
+        self.kernel_cache = {}
+
+        # Store the key function
+        self.key_fn = key_fn
+
+    def __getitem__(self, grid):
+        """
+        Index with grid to get a launcher function.
+        Returns a launcher that will handle caching based on the key function.
+        """
+        assert (
+            isinstance(grid, tuple) and len(grid) <= 3
+        ), "Grid must be a tuple with at most 3 dimensions."
+
+        # Normalize grid once
+        if len(grid) < 3:
+            grid = grid + (1,) * (3 - len(grid))
+
+        def launcher(*args, **kwargs):
+            cache_key = self.key_fn(args, kwargs)
+
+            cached_kernel = self.kernel_cache.get(cache_key)
+            if cached_kernel is None:
+                # First time: compile and cache the kernel
+                cached_kernel = self.fn[grid](*args, **kwargs)
+                self.kernel_cache[cache_key] = cached_kernel
+                return cached_kernel
+            else:
+                # Use cached kernel
+                all_args = self._build_args(args, kwargs)
+                cached_kernel[grid](*all_args)
+                return cached_kernel
+
+        return launcher
+
+    def _build_args(self, args, kwargs):
+        """
+        Build the complete argument list for kernel invocation.
+        """
+        complete_args = list(args)
+        for i in range(len(args), self.num_args):
+            name = self.param_names[i]
+            value = kwargs.get(name, inspect.Parameter.empty)
+            if value is not inspect.Parameter.empty:
+                complete_args.append(value)
+            else:
+                raise ValueError(f"Missing argument: {name}")
+
+        return complete_args
+
+
+def cached_triton_kernel(key_fn=None):
+    """
+    Decorator that enables key-based caching for Triton kernels using a key function.
+
+    It essentially bypasses Triton's built-in caching mechanism, allowing users to
+    define their own caching strategy based on kernel parameters. This helps reduce
+    the heavy overheads of Triton kernel launch when the kernel specialization dispatch
+    is simple.
+
+    Usage:
+        @cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get('BLOCK_SIZE', 1024))
+        @triton.jit
+        def my_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+            ...
+
+        # Invoke normally
+        my_kernel[grid](x, y, BLOCK_SIZE=1024)
+
+    Args:
+        key_fn: A function that takes (args, kwargs) and returns the cache key(s).
+                The key can be a single value or a tuple of values.
+
+    Returns:
+        A decorator that wraps the kernel with caching functionality.
+
+    Note: Kernels with default parameter values are not supported and will raise an assertion error.
+    """
+
+    def decorator(fn):
+        return CachedKernel(fn, key_fn)
+
+    return decorator
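The new decorator is exercised exactly as its docstring shows. A runnable sketch of a full round trip (assumes a CUDA device and a Triton install; the add kernel itself is illustrative and not part of the commit):

import torch
import triton
import triton.language as tl

from sglang.utils import cached_triton_kernel


@cached_triton_kernel(key_fn=lambda args, kwargs: kwargs.get("BLOCK_SIZE", 1024))
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # each program instance handles one BLOCK_SIZE-wide slice of the vectors
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)

# the first launch compiles and caches under the key BLOCK_SIZE=1024;
# later launches with the same key reuse the cached compiled kernel
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)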