sglang / Commits / 61f42b57

Unverified commit 61f42b57, authored Jan 19, 2025 by Lianmin Zheng, committed by GitHub on Jan 19, 2025.

    Move sgl.Runtime under sglang/lang (#2990)

Parent: e403d237

Showing 17 changed files with 267 additions and 329 deletions (+267 -329):
  examples/frontend_language/usage/json_decode.py                                  +1    -1
  examples/frontend_language/usage/triton/models/character_generation/1/model.py   +2    -2
  examples/runtime/async_io_api.py                                                 +0    -46
  python/sglang/api.py                                                             +1    -6
  python/sglang/bench_offline_throughput.py                                        +2    -1
  python/sglang/lang/backend/runtime_endpoint.py                                   +168  -1
  python/sglang/launch_server_llavavid.py                                          +0    -25
  python/sglang/srt/constrained/__init__.py                                        +0    -16
  python/sglang/srt/constrained/base_grammar_backend.py                            +21   -0
  python/sglang/srt/managers/scheduler.py                                          +57   -52
  python/sglang/srt/managers/tokenizer_manager.py                                  +2    -2
  python/sglang/srt/server.py                                                      +0    -160
  python/sglang/test/runners.py                                                    +8    -12
  scripts/deprecated/test_jump_forward.py                                          +1    -1
  test/lang/test_srt_backend.py                                                    +1    -1
  test/srt/models/test_qwen_models.py                                              +1    -1
  test/srt/models/test_reward_models.py                                            +2    -2
examples/frontend_language/usage/json_decode.py  (+1 -1)

@@ -9,7 +9,7 @@ from enum import Enum
 from pydantic import BaseModel

 import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object

 character_regex = (
     r"""\{\n"""
examples/frontend_language/usage/triton/models/character_generation/1/model.py  (+2 -2)

@@ -3,8 +3,8 @@ import triton_python_backend_utils as pb_utils
 from pydantic import BaseModel

 import sglang as sgl
-from sglang import function, set_default_backend
-from sglang.srt.constrained import build_regex_from_object
+from sglang import function
+from sglang.srt.constrained.outlines_backend import build_regex_from_object

 sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
examples/runtime/async_io_api.py  (deleted, mode 100644 → 0; +0 -46)

Deleted file contents:

"""
Usage:
python3 async_io.py
"""

import asyncio

from sglang import Runtime


async def generate(
    engine,
    prompt,
    sampling_params,
):
    tokenizer = engine.get_tokenizer()

    messages = [
        {
            "role": "system",
            "content": "You will be given question answer tasks.",
        },
        {"role": "user", "content": prompt},
    ]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    stream = engine.add_request(prompt, sampling_params)

    async for output in stream:
        print(output, end="", flush=True)
    print()


if __name__ == "__main__":
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    print("--- runtime ready ---\n")

    prompt = "Who is Alan Turing?"
    sampling_params = {"max_new_tokens": 128}
    asyncio.run(generate(runtime, prompt, sampling_params))
    runtime.shutdown()
python/sglang/api.py  (+1 -6)

 """Public APIs of the language."""

-import os
 import re
 from typing import Callable, List, Optional, Union

@@ -33,17 +32,13 @@ def function(

 def Runtime(*args, **kwargs):
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
     # Avoid importing unnecessary dependency
-    from sglang.srt.server import Runtime
+    from sglang.lang.backend.runtime_endpoint import Runtime

     return Runtime(*args, **kwargs)


 def Engine(*args, **kwargs):
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
     # Avoid importing unnecessary dependency
     from sglang.srt.server import Engine
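For callers, nothing changes at the package level: sgl.Runtime(...) remains a thin factory that now lazily imports the class from its new module. A minimal sketch (not part of the commit; the model path is only a placeholder):

import sglang as sgl
from sglang.lang.backend.runtime_endpoint import Runtime  # new canonical location

# sglang.api.Runtime forwards *args/**kwargs to the relocated class.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
assert isinstance(runtime, Runtime)
runtime.shutdown()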
python/sglang/bench_offline_throughput.py  (+2 -1)

@@ -27,7 +27,8 @@ from sglang.bench_serving import (
     sample_random_requests,
     set_ulimit,
 )
-from sglang.srt.server import Engine, Runtime
+from sglang.lang.backend.runtime_endpoint import Runtime
+from sglang.srt.server import Engine
 from sglang.srt.server_args import ServerArgs
python/sglang/lang/backend/runtime_endpoint.py  (+168 -1)

+import atexit
 import json
+import multiprocessing
 import warnings
-from typing import List, Optional
+from typing import Dict, List, Optional, Union

+import aiohttp
+import requests
+
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend

@@ -14,6 +19,9 @@ from sglang.lang.ir import (
     REGEX_STR,
     SglSamplingParams,
 )
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available, kill_process_tree
 from sglang.utils import http_request

@@ -325,3 +333,162 @@ class RuntimeEndpoint(BaseBackend):
 def compute_normalized_prompt_logprobs(input_logprobs):
     values = [x[0] for x in input_logprobs if x[0]]
     return sum(values) / len(values)

The following Runtime class is appended to the file:

class Runtime:
    """
    A wrapper for the HTTP server.
    This is used for launching the server in a python program without
    using the commond line interface.

    It is mainly used for the frontend language.
    You should use the Engine class if you want to do normal offline processing.
    """

    def __init__(
        self,
        log_level: str = "error",
        *args,
        **kwargs,
    ):
        """See the arguments in server_args.py::ServerArgs"""
        from sglang.srt.server import launch_server

        self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

        # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
        atexit.register(self.shutdown)

        # Pre-allocate ports
        for port in range(self.server_args.port, 40000):
            if is_port_available(port):
                break
        self.server_args.port = port

        self.url = self.server_args.url()
        self.generate_url = self.url + "/generate"

        # NOTE: We store pid instead of proc to fix some issues during __delete__
        self.pid = None
        pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)

        proc = multiprocessing.Process(
            target=launch_server,
            args=(self.server_args, pipe_writer),
        )
        proc.start()
        pipe_writer.close()
        self.pid = proc.pid

        try:
            init_state = pipe_reader.recv()
        except EOFError:
            init_state = ""

        if init_state != "ready":
            self.shutdown()
            raise RuntimeError(
                "Initialization failed. Please see the error messages above."
            )

        self.endpoint = RuntimeEndpoint(self.url)

    def shutdown(self):
        if self.pid is not None:
            kill_process_tree(self.pid)
            self.pid = None

    def cache_prefix(self, prefix: str):
        self.endpoint.cache_prefix(prefix)

    def get_tokenizer(self):
        return get_tokenizer(
            self.server_args.tokenizer_path,
            tokenizer_mode=self.server_args.tokenizer_mode,
            trust_remote_code=self.server_args.trust_remote_code,
            revision=self.server_args.revision,
        )

    async def async_generate(
        self,
        prompt: str,
        sampling_params: Optional[Dict] = None,
    ):
        if self.server_args.skip_tokenizer_init:
            json_data = {
                "input_ids": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        else:
            json_data = {
                "text": prompt,
                "sampling_params": sampling_params,
                "stream": True,
            }
        pos = 0

        timeout = aiohttp.ClientTimeout(total=3 * 3600)
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.post(self.generate_url, json=json_data) as response:
                async for chunk, _ in response.content.iter_chunks():
                    chunk = chunk.decode("utf-8")
                    if chunk and chunk.startswith("data:"):
                        if chunk == "data: [DONE]\n\n":
                            break
                        data = json.loads(chunk[5:].strip("\n"))
                        if "text" in data:
                            cur = data["text"][pos:]
                            if cur:
                                yield cur
                            pos += len(cur)
                        else:
                            yield data

    add_request = async_generate

    def generate(
        self,
        prompt: Union[str, List[str]],
        sampling_params: Optional[Dict] = None,
        return_logprob: Optional[Union[List[bool], bool]] = False,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[List[Optional[str]]] = None,
    ):
        json_data = {
            "text": prompt,
            "sampling_params": sampling_params,
            "return_logprob": return_logprob,
            "logprob_start_len": logprob_start_len,
            "top_logprobs_num": top_logprobs_num,
            "lora_path": lora_path,
        }
        assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
        response = requests.post(
            self.url + "/generate",
            json=json_data,
        )
        return json.dumps(response.json())

    def encode(
        self,
        prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
    ):
        json_data = {"text": prompt}
        response = requests.post(self.url + "/encode", json=json_data)
        return json.dumps(response.json())

    async def get_server_info(self):
        async with aiohttp.ClientSession() as session:
            async with session.get(f"{self.url}/get_server_info") as response:
                if response.status == 200:
                    return await response.json()
                else:
                    error_data = await response.json()
                    raise RuntimeError(
                        f"Failed to get server info. {error_data['error']['message']}"
                    )

    def __del__(self):
        self.shutdown()
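The relocated class keeps the same interface it had in sglang.srt.server, so the pattern from the removed examples/runtime/async_io_api.py still applies against the new import path. A minimal sketch (not part of the commit; model path and prompt are placeholders):

import asyncio

from sglang.lang.backend.runtime_endpoint import Runtime


async def stream(runtime, prompt, sampling_params):
    # add_request is an alias of async_generate and yields incremental text chunks
    async for chunk in runtime.add_request(prompt, sampling_params):
        print(chunk, end="", flush=True)
    print()


if __name__ == "__main__":
    runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
    asyncio.run(stream(runtime, "Who is Alan Turing?", {"max_new_tokens": 128}))
    runtime.shutdown()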
python/sglang/launch_server_llavavid.py  (deleted, mode 100644 → 0; +0 -25)

Deleted file contents:

"""Launch the inference server for Llava-video model."""

import json
import sys

from sglang.srt.server import launch_server, prepare_server_args

if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])

    model_override_args = {}
    model_override_args["mm_spatial_pool_stride"] = 2
    model_override_args["architectures"] = ["LlavaVidForCausalLM"]
    model_override_args["num_frames"] = 16
    model_override_args["model_type"] = "llavavid"

    if model_override_args["num_frames"] == 32:
        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
        model_override_args["max_sequence_length"] = 4096 * 2
        model_override_args["tokenizer_model_max_length"] = 4096 * 2
        model_override_args["model_max_length"] = 4096 * 2

    if "34b" in server_args.model_path.lower():
        model_override_args["image_token_index"] = 64002

    server_args.json_model_override_args = json.dumps(model_override_args)

    launch_server(server_args)
python/sglang/srt/constrained/__init__.py  (deleted, mode 100644 → 0; +0 -16)

Deleted file contents:

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# TODO(lmzheng): make this an optional dependency
from sglang.srt.constrained.outlines_backend import build_regex_from_object
python/sglang/srt/constrained/base_grammar_backend.py  (+21 -0)

@@ -18,6 +18,8 @@ from dataclasses import dataclass
 from threading import Event, Lock
 from typing import Any, Optional, Tuple

+from sglang.srt.server_args import ServerArgs
+

 @dataclass
 class CacheEntry:

@@ -69,3 +71,22 @@ class BaseGrammarBackend:
     def reset(self):
         with self.cache_lock:
             self.cache.clear()
+
+
+def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+    if server_args.grammar_backend == "outlines":
+        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
+
+        grammar_backend = OutlinesGrammarBackend(
+            tokenizer,
+            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            allow_jump_forward=not server_args.disable_jump_forward,
+        )
+    elif server_args.grammar_backend == "xgrammar":
+        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
+
+        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+    else:
+        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+
+    return grammar_backend
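This factory replaces the inline backend selection that previously lived in the scheduler (see the scheduler.py hunks below). A sketch of the intended call site, where server_args, tokenizer, and vocab_size come from the caller's context rather than from this snippet:

from sglang.srt.constrained.base_grammar_backend import create_grammar_backend

# server_args.grammar_backend selects "outlines" or "xgrammar"; any other value raises ValueError.
grammar_backend = create_grammar_backend(server_args, tokenizer, vocab_size)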
python/sglang/srt/managers/scheduler.py  (+57 -52)

@@ -34,6 +34,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -149,9 +150,7 @@ class Scheduler:
             else 1
         )

-        # Init inter-process communication
-        context = zmq.Context(2)
-
+        # Distributed rank info
         self.dp_size = server_args.dp_size
         self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
             compute_dp_attention_world_info(

@@ -162,6 +161,8 @@ class Scheduler:
             )
         )

+        # Init inter-process communication
+        context = zmq.Context(2)
         if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False

@@ -243,7 +244,7 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

-        # Launch worker for speculative decoding if need
+        # Launch a worker for speculative decoding if needed
         if self.spec_algorithm.is_eagle():
             from sglang.srt.speculative.eagle_worker import EAGLEWorker

@@ -316,6 +317,8 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
         self.current_stream = torch.get_device_module(self.device).current_stream()

@@ -337,28 +340,9 @@ class Scheduler:
         # Init the grammar backend for constrained generation
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
-            if server_args.grammar_backend == "outlines":
-                from sglang.srt.constrained.outlines_backend import (
-                    OutlinesGrammarBackend,
-                )
-
-                self.grammar_backend = OutlinesGrammarBackend(
-                    self.tokenizer,
-                    whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-                    allow_jump_forward=not server_args.disable_jump_forward,
-                )
-            elif server_args.grammar_backend == "xgrammar":
-                from sglang.srt.constrained.xgrammar_backend import (
-                    XGrammarGrammarBackend,
-                )
-
-                self.grammar_backend = XGrammarGrammarBackend(
-                    self.tokenizer, vocab_size=self.model_config.vocab_size
-                )
-            else:
-                raise ValueError(
-                    f"Invalid grammar backend: {server_args.grammar_backend}"
-                )
+            self.grammar_backend = create_grammar_backend(
+                server_args, self.tokenizer, self.model_config.vocab_size
+            )
         else:
             self.grammar_backend = None

@@ -424,7 +408,8 @@ class Scheduler:
             },
         )

-        self._dispatcher = TypeBasedDispatcher(
+        # Init request dispatcher
+        self._request_dispatcher = TypeBasedDispatcher(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),

@@ -480,10 +465,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:

@@ -506,10 +487,6 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
-
-            if self.server_args.enable_dp_attention:  # TODO: simplify this
-                batch = self.prepare_dp_attn_batch(batch)
-
             self.cur_batch = batch

             if batch:

@@ -517,7 +494,7 @@ class Scheduler:
                 result_queue.append((batch.copy(), result))

                 if self.last_batch is None:
-                    # Create a dummy first batch to start the pipeline for overlap scheduler.
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
                     # It is now used for triggering the sampling_info_done event.
                     tmp_batch = ScheduleBatch(
                         reqs=None,

@@ -593,7 +570,7 @@ class Scheduler:
     def process_input_requests(self, recv_reqs: List):
         for recv_req in recv_reqs:
-            output = self._dispatcher(recv_req)
+            output = self._request_dispatcher(recv_req)
             if output is not None:
                 self.send_to_tokenizer.send_pyobj(output)

@@ -798,15 +775,32 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
-        logger.info(
-            f"Decode batch. "
-            f"#running-req: {num_running_reqs}, "
-            f"#token: {num_used}, "
-            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"gen throughput (token/s): {gen_throughput:.2f}, "
-            f"#queue-req: {len(self.waiting_queue)}"
-        )
+
+        if self.spec_algorithm.is_none():
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+        else:
+            accept_length = (
+                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+            )
+            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            msg = (
+                f"Decode batch. "
+                f"#running-req: {num_running_reqs}, "
+                f"#token: {num_used}, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                f"accept len: {accept_length:.2f}, "
+                f"gen throughput (token/s): {gen_throughput:.2f}, "
+                f"#queue-req: {len(self.waiting_queue)}"
+            )
+
+        logger.info(msg)
+
         if self.enable_metrics:
             self.stats.num_running_reqs = num_running_reqs
             self.stats.num_used_tokens = num_used

@@ -855,16 +849,23 @@ class Scheduler:
         else:
             self.running_batch.merge_batch(self.last_batch)

-        # Run prefill first if possible
         new_batch = self.get_new_batch_prefill()
         if new_batch is not None:
-            return new_batch
-
-        # Run decode
-        if self.running_batch is None:
-            return None
-        self.running_batch = self.update_running_batch(self.running_batch)
-        return self.running_batch
+            # Run prefill first if possible
+            ret = new_batch
+        else:
+            # Run decode
+            if self.running_batch is None:
+                ret = None
+            else:
+                self.running_batch = self.update_running_batch(self.running_batch)
+                ret = self.running_batch
+
+        # Handle DP attention
+        if self.server_args.enable_dp_attention:
+            ret = self.prepare_dp_attn_batch(ret)
+
+        return ret

     def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Check if the grammar is ready in the grammar queue

@@ -1053,6 +1054,10 @@ class Scheduler:
                     model_worker_batch,
                     num_accepted_tokens,
                 ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             else:
                 assert False, "batch.extend_num_tokens == 0, this is unexpected!"
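The two new counters feed the "accept len" figure in the decode log line: each speculative forward pass adds num_accepted_tokens plus batch.batch_size() to the accepted-token total and batch.batch_size() to the forward count, and the logger reports their ratio before resetting both. Illustrative arithmetic only (the numbers are made up, not taken from the commit):

# Two speculative forward passes over a batch of 4 requests, accepting
# 6 and 10 draft tokens respectively:
spec_num_total_accepted_tokens = (6 + 4) + (10 + 4)  # += num_accepted_tokens + batch.batch_size()
spec_num_total_forward_ct = 4 + 4                    # += batch.batch_size()
accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(f"accept len: {accept_length:.2f}")            # prints "accept len: 3.00"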
python/sglang/srt/managers/tokenizer_manager.py  (+2 -2)

@@ -224,7 +224,7 @@ class TokenizerManager:
             },
         )

-        self._dispatcher = TypeBasedDispatcher(
+        self._result_dispatcher = TypeBasedDispatcher(
             [
                 (BatchStrOut, self._handle_batch_output),
                 (BatchEmbeddingOut, self._handle_batch_output),

@@ -760,7 +760,7 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
-            self._dispatcher(recv_obj)
+            self._result_dispatcher(recv_obj)

     def _handle_batch_output(
         self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
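Both renames only disambiguate the two TypeBasedDispatcher instances: the scheduler's _request_dispatcher routes incoming tokenized requests, while the tokenizer manager's _result_dispatcher routes batch outputs coming back from the detokenizer. A minimal sketch of the dispatch pattern (this is not the real TypeBasedDispatcher from sglang's utilities, only an illustration of how the renamed fields behave):

class TypeBasedDispatcher:
    """Route an object to the handler registered for its type."""

    def __init__(self, mapping):
        # mapping is a list of (type, handler) pairs, as in the diffs above
        self._mapping = mapping

    def __call__(self, obj):
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"No handler registered for {type(obj)}")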
python/sglang/srt/server.py  (+0 -160)

@@ -45,8 +45,6 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )

@@ -90,7 +88,6 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     delete_directory,
-    is_port_available,
     kill_process_tree,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,

@@ -960,160 +957,3 @@ class Engine:
         obj = ResumeMemoryOccupationReqInput()
         loop = asyncio.get_event_loop()
         loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
-
-
-class Runtime:
-    """
-    A wrapper for the HTTP server.
-    This is used for launching the server in a python program without
-    using the commond line interface.
-
-    It is mainly used for the frontend language.
-    You should use the Engine class above if you want to do normal offline processing.
-    """

(The rest of the deleted block is the same Runtime implementation that now lives in
python/sglang/lang/backend/runtime_endpoint.py, shown above: __init__, shutdown,
cache_prefix, get_tokenizer, async_generate/add_request, generate, encode,
get_server_info, and __del__. The old copy used the mp alias for multiprocessing and
referred to "the Engine class above" in its docstring.)
python/sglang/test/runners.py  (+8 -12)

@@ -23,7 +23,7 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
+from sglang.srt.server import Engine
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

 DEFAULT_PROMPTS = [

@@ -278,7 +278,7 @@ class SRTRunner:
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
-        self.runtime = Runtime(
+        self.engine = Engine(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),

@@ -306,7 +306,7 @@ class SRTRunner:
             top_output_logprobs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for i, prompt in enumerate(prompts):
-                response = self.runtime.generate(
+                response = self.engine.generate(
                     prompt,
                     lora_path=lora_paths[i] if lora_paths else None,
                     sampling_params=sampling_params,

@@ -314,7 +314,6 @@ class SRTRunner:
                     logprob_start_len=0,
                     top_logprobs_num=NUM_TOP_LOGPROBS,
                 )
-                response = json.loads(response)
                 output_strs.append(response["text"])
                 top_input_logprobs.append(
                     [

@@ -343,8 +342,7 @@ class SRTRunner:
                 top_output_logprobs=top_output_logprobs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)

@@ -366,20 +364,18 @@ class SRTRunner:
             # the return value contains logprobs from prefill
             output_strs = []
             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.runtime.generate(
+            response = self.engine.generate(
                 prompts,
                 lora_path=lora_paths if lora_paths else None,
                 sampling_params=sampling_params,
             )
-            response = json.loads(response)
             output_strs = [r["text"] for r in response]

             return ModelOutput(
                 output_strs=output_strs,
             )
         else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)

@@ -391,8 +387,8 @@ class SRTRunner:
         return self

     def __exit__(self, exc_type, exc_value, traceback):
-        self.runtime.shutdown()
-        del self.runtime
+        self.engine.shutdown()
+        del self.engine


 def monkey_patch_gemma2_sdpa():
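The dropped json.loads calls reflect the interface difference between the two wrappers: Runtime talks to the HTTP server and returns JSON strings, while Engine runs in process and returns Python objects directly, which is why SRTRunner now consumes responses as-is. A sketch of that difference under the assumption that sgl.Runtime and sgl.Engine are both constructed with only a model path (placeholder values, not part of the commit):

import json

import sglang as sgl

# HTTP wrapper: generate() returns a JSON string that must be decoded.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
out = json.loads(runtime.generate("Hello", sampling_params={"max_new_tokens": 8}))
print(out["text"])
runtime.shutdown()

# In-process engine (the style SRTRunner now uses): generate() returns Python objects.
engine = sgl.Engine(model_path="meta-llama/Llama-2-7b-chat-hf")
out = engine.generate("Hello", sampling_params={"max_new_tokens": 8})
print(out["text"])
engine.shutdown()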
scripts/deprecated/test_jump_forward.py  (+1 -1)

@@ -4,7 +4,7 @@ from enum import Enum
 from pydantic import BaseModel, constr

 import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
test/lang/test_srt_backend.py  (+1 -1)

@@ -73,7 +73,7 @@ class TestSRTBackend(unittest.TestCase):
         # Run twice to capture more bugs
         for _ in range(2):
             accuracy, latency = test_hellaswag_select()
-            self.assertGreater(accuracy, 0.71)
+            self.assertGreater(accuracy, 0.70)

     def test_gen_min_new_tokens(self):
         test_gen_min_new_tokens()
test/srt/models/test_qwen_models.py  (+1 -1)

@@ -71,7 +71,7 @@ class TestQwen2FP8(unittest.TestCase):
         metrics = run_eval(args)
         print(metrics)
-        self.assertGreater(metrics["accuracy"], 0.8)
+        self.assertGreater(metrics["accuracy"], 0.79)


 if __name__ == "__main__":
test/srt/models/test_reward_models.py  (+2 -2)

@@ -20,8 +20,8 @@ import torch
 from sglang.test.runners import HFRunner, SRTRunner

 MODELS = [
-    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
-    ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2),
+    ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
+    ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
 ]

 TORCH_DTYPES = [torch.float16]