sglang / Commits / c51020cf

Unverified commit c51020cf · authored Feb 11, 2024 by Lianmin Zheng · committed by GitHub, Feb 11, 2024

Fix the chat template for llava-v1.6-34b & format code (#177)

parent 50afed4e
Changes: 23 files in total; this page shows 20 changed files with 92 additions and 37 deletions (+92, -37).
python/sglang/api.py (+1, -0)
python/sglang/backend/runtime_endpoint.py (+24, -9)
python/sglang/lang/chat_template.py (+18, -1)
python/sglang/lang/ir.py (+3, -3)
python/sglang/lang/tracer.py (+1, -0)
python/sglang/srt/backend_config.py (+1, -0)
python/sglang/srt/conversation.py (+2, -1)
python/sglang/srt/managers/router/model_rpc.py (+10, -4)
python/sglang/srt/managers/router/model_runner.py (+9, -3)
python/sglang/srt/memory_pool.py (+1, -0)
python/sglang/srt/models/llava.py (+1, -1)
python/sglang/srt/models/mistral.py (+1, -0)
python/sglang/srt/models/mixtral.py (+9, -7)
python/sglang/srt/models/yivl.py (+1, -0)
python/sglang/srt/sampling_params.py (+1, -0)
python/sglang/srt/server.py (+4, -3)
python/sglang/srt/utils.py (+1, -3)
python/sglang/test/test_utils.py (+1, -0)
python/sglang/utils.py (+2, -2)
test/lang/test_srt_backend.py (+1, -0)
python/sglang/api.py (view file @ c51020cf)

 """Public API"""
 import re
 from typing import Callable, List, Optional, Union
 ...
python/sglang/backend/runtime_endpoint.py (view file @ c51020cf)

...
@@ -19,7 +19,9 @@ class RuntimeEndpoint(BaseBackend):
         self.base_url = base_url
         self.auth_token = auth_token

-        res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/get_model_info", auth_token=self.auth_token
+        )
         assert res.status_code == 200
         self.model_info = res.json()
...
@@ -37,7 +39,7 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200
...
@@ -45,14 +47,16 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200

     def fill_image(self, s: StreamExecutor):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         assert res.status_code == 200

     def generate(
...
@@ -82,7 +86,9 @@ class RuntimeEndpoint(BaseBackend):
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         obj = res.json()
         comp = obj["text"]
         return comp, obj["meta_info"]
...
@@ -115,7 +121,12 @@ class RuntimeEndpoint(BaseBackend):
         data["stream"] = True
         self._add_images(s, data)
-        response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
+        response = http_request(
+            self.base_url + "/generate",
+            json=data,
+            stream=True,
+            auth_token=self.auth_token,
+        )
         pos = 0
         incomplete_text = ""
...
@@ -145,7 +156,9 @@ class RuntimeEndpoint(BaseBackend):
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         assert res.status_code == 200
         prompt_len = res.json()["meta_info"]["prompt_tokens"]
...
@@ -157,7 +170,9 @@ class RuntimeEndpoint(BaseBackend):
             "logprob_start_len": max(prompt_len - 2, 0),
         }
         self._add_images(s, data)
-        res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
+        res = http_request(
+            self.base_url + "/generate", json=data, auth_token=self.auth_token
+        )
         assert res.status_code == 200
         obj = res.json()
         normalized_prompt_logprob = [
...
@@ -172,7 +187,7 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token
+            auth_token=self.auth_token,
         )
         assert res.status_code == 200
...
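Note: the hunks above repeatedly touch one request pattern: a /generate call with max_new_tokens set to 0, which tokenizes the text and primes the server's prefix cache without sampling anything. A minimal standalone sketch of that pattern (hedged: cache_prefix is an illustrative helper, not part of this commit; the Authentication header mirrors http_request in python/sglang/utils.py further down):

import requests

def cache_prefix(base_url, prefix, auth_token=None):
    # Zero-token generation: tokenize and cache the prefix, return its length.
    headers = {"Content-Type": "application/json"}
    if auth_token is not None:
        headers["Authentication"] = f"Bearer {auth_token}"
    res = requests.post(
        base_url + "/generate",
        json={"text": prefix, "sampling_params": {"max_new_tokens": 0}},
        headers=headers,
    )
    assert res.status_code == 200
    return res.json()["meta_info"]["prompt_tokens"]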
python/sglang/lang/chat_template.py (view file @ c51020cf)

...
@@ -116,6 +116,21 @@ register_chat_template(
 )

+register_chat_template(
+    ChatTemplate(
+        name="chatml-llava",
+        default_system_prompt="Answer the questions.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
+        image_token=" <image>\n",
+    )
+)

 register_chat_template(
     ChatTemplate(
         name="vicuna_v1.1",
...
@@ -168,7 +183,7 @@ register_chat_template(
 def match_vicuna(model_path: str):
     if "vicuna" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
-    if "llava" in model_path.lower():
+    if "llava-v1.5" in model_path.lower():
         return get_chat_template("vicuna_v1.1")
...
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
         return get_chat_template("chatml")
     if "qwen" in model_path and "chat" in model_path:
         return get_chat_template("chatml")
+    if "llava-v1.6-34b" in model_path:
+        return get_chat_template("chatml-llava")

 @register_chat_template_matching_function
...
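Note: for reference, a hedged sketch (not part of the commit) of the prompt the new "chatml-llava" template yields for a single user turn with one image; every string literal is copied from the registration above:

system = ("<|im_start|>system\n", "\n<|im_end|>\n")
user = ("<|im_start|>user\n", "\n<|im_end|>\n")
assistant = ("<|im_start|>assistant\n", "\n<|im_end|>\n")

prompt = (
    system[0] + "Answer the questions." + system[1]             # default system prompt
    + user[0] + " <image>\nWhat is in this picture?" + user[1]  # image_token + text
    + assistant[0]  # generation starts here and stops at stop_str "<|im_end|>"
)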
python/sglang/lang/ir.py (view file @ c51020cf)

...
@@ -74,9 +74,9 @@ class SglSamplingParams:
         )
         return {
             "max_tokens_to_sample": self.max_new_tokens,
-            "stop_sequences": self.stop
-            if isinstance(self.stop, (list, tuple))
-            else [self.stop],
+            "stop_sequences": (
+                self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
+            ),
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
...
python/sglang/lang/tracer.py (view file @ c51020cf)

 """Tracing a program."""
 import uuid
 from typing import Any, Callable, Dict, List, Optional, Union
 ...
python/sglang/srt/backend_config.py (view file @ c51020cf)

 """
 Backend configurations, may vary with different serving platforms.
 """
 from dataclasses import dataclass
 ...
python/sglang/srt/conversation.py (view file @ c51020cf)

...
@@ -366,7 +366,8 @@ def generate_chat_conv(
                 if content.type == "text":
                     real_content += content.text
                 elif content.type == "image_url":
-                    real_content += "<image>"
+                    # NOTE: Only works for llava
+                    real_content += "<image>\n"
                     conv.append_image(content.image_url.url)
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
...
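Note: the image placeholder now ends with a newline ("<image>\n"), consistent with the newline-terminated image_token of the "chatml-llava" template registered above. A hedged sketch (not part of the commit; it uses plain dicts where the real code uses message objects) of the flattening loop this hunk edits:

def flatten_content(parts):
    # Collapse mixed text/image message parts into one prompt string,
    # collecting image URLs on the side.
    real_content, images = "", []
    for content in parts:
        if content["type"] == "text":
            real_content += content["text"]
        elif content["type"] == "image_url":
            # NOTE: Only works for llava (comment from the commit)
            real_content += "<image>\n"
            images.append(content["image_url"]["url"])
    return real_content, images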
python/sglang/srt/managers/router/model_rpc.py (view file @ c51020cf)

...
@@ -31,6 +31,7 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
+from vllm.logger import _default_handler as vllm_default_handler

 logger = logging.getLogger("model_rpc")
...
@@ -50,6 +51,9 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        vllm_default_handler.setLevel(
+            level=getattr(logging, server_args.log_level.upper())
+        )

         # Init model and tokenizer
         self.model_config = ModelConfig(
...
@@ -83,9 +87,11 @@ class ModelRpcServer(rpyc.Service):
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
             self.model_config.context_len,
-            self.max_total_num_token // 6
-            if server_args.max_prefill_num_token is None
-            else server_args.max_prefill_num_token,
+            (
+                self.max_total_num_token // 6
+                if server_args.max_prefill_num_token is None
+                else server_args.max_prefill_num_token
+            ),
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
...
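Note: the new setLevel call maps the server's string log level onto a constant from the logging module via getattr, so vLLM's default handler follows the configured level. A hedged sketch of the same idiom (the StreamHandler here is illustrative; only vllm_default_handler appears in the commit):

import logging

log_level = "warning"  # e.g. server_args.log_level
handler = logging.StreamHandler()
handler.setLevel(level=getattr(logging, log_level.upper()))  # "WARNING" -> 30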
python/sglang/srt/managers/router/model_runner.py (view file @ c51020cf)

...
@@ -112,7 +112,9 @@ class InputMetadata:
                 (self.batch_size,), dtype=torch.int32, device="cuda"
             )
-            workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
+            workspace_buffer = torch.empty(
+                32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+            )
             if (
                 self.forward_mode == ForwardMode.PREFILL
                 or self.forward_mode == ForwardMode.EXTEND
...
@@ -121,7 +123,9 @@ class InputMetadata:
                     (self.batch_size + 1,), dtype=torch.int32, device="cuda"
                 )
                 self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-                self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+                self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                    workspace_buffer, "NHD"
+                )
                 self.prefill_wrapper.begin_forward(
                     self.qo_indptr,
                     self.kv_indptr,
...
@@ -131,7 +135,9 @@ class InputMetadata:
                     self.model_runner.model_config.num_key_value_heads // tp_size,
                 )
             else:
-                self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
+                self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                    workspace_buffer, "NHD"
+                )
                 self.decode_wrapper.begin_forward(
                     self.kv_indptr,
                     self.kv_indices,
...
python/sglang/srt/memory_pool.py (view file @ c51020cf)

 """Memory pool."""
 import logging

 import torch
 ...
python/sglang/srt/models/llava.py (view file @ c51020cf)

 """Inference-only LLaVa model compatible with HuggingFace weights."""
 from typing import List, Optional

 import numpy as np
 ...
@@ -269,7 +270,6 @@ class LlavaLlamaForCausalLM(nn.Module):
             raise ValueError(f"Unexpected select feature: {self.select_feature}")

         # load mm_projector
-        # TODO: support TP?
         projector_weights = {
             "model.mm_projector.0": "multi_modal_projector.linear_1",
             "model.mm_projector.2": "multi_modal_projector.linear_2",
...
python/sglang/srt/models/mistral.py (view file @ c51020cf)

 """Inference-only Mistral model."""
 from sglang.srt.models.llama2 import LlamaForCausalLM
 ...
python/sglang/srt/models/mixtral.py (view file @ c51020cf)

...
@@ -97,6 +97,7 @@ class MixtralMoE(nn.Module):
         self.experts = nn.ModuleList(
             [
+                (
                 MixtralMLP(
                     self.num_total_experts,
                     config.hidden_size,
...
@@ -105,6 +106,7 @@ class MixtralMoE(nn.Module):
                 )
                 if idx in self.expert_indicies
                 else None
+                )
                 for idx in range(self.num_total_experts)
             ]
         )
...
python/sglang/srt/models/yivl.py (view file @ c51020cf)

 """Inference-only Yi-VL model."""
 import os
 from typing import List, Optional
 ...
python/sglang/srt/sampling_params.py (view file @ c51020cf)

 """Sampling parameters for text generation."""
 from typing import List, Optional, Union

 _SAMPLING_EPS = 1e-6
 ...
python/sglang/srt/server.py (view file @ c51020cf)

 """SRT: SGLang Runtime"""
 import asyncio
 import json
 import multiprocessing as mp
 ...
@@ -493,7 +494,7 @@ def launch_server(server_args, pipe_finish_writer):
     # Warmup
     try:
-        print("Warmup...", flush=True)
+        # print("Warmup...", flush=True)
         res = requests.post(
             url + "/generate",
             json={
...
@@ -505,8 +506,8 @@ def launch_server(server_args, pipe_finish_writer):
             },
             timeout=60,
         )
-        print(f"Warmup done. model response: {res.json()['text']}")
-        print("=" * 20, "Server is ready", "=" * 20, flush=True)
+        # print(f"Warmup done. model response: {res.json()['text']}")
+        # print("=" * 20, "Server is ready", "=" * 20, flush=True)
     except requests.exceptions.RequestException as e:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(str(e))
...
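Note: the warmup request itself is unchanged; only its status prints are commented out. A hedged sketch (port, prompt, and sampling parameters are illustrative, not from the elided json body) of the equivalent client-side call:

import requests

res = requests.post(
    "http://127.0.0.1:30000/generate",
    json={"text": "Say hello.", "sampling_params": {"max_new_tokens": 8}},
    timeout=60,
)
print(res.json()["text"])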
python/sglang/srt/utils.py (view file @ c51020cf)

...
@@ -122,7 +122,7 @@ def handle_port_init(
     # first check on server port
     if not check_port(port):
         new_port = alloc_usable_network_port(1, used_list=[port])[0]
-        print(f"Port {port} is not available, using {new_port} instead.")
+        print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
         port = new_port

     # then we check on additional ports
...
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
         ss = tokenizer.decode([t_id]).strip()
         if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
             logit_bias[t_id] = -1e5
-        # else:
-        #     print(ss, t_id)

     return logit_bias
...
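Note: the second hunk only drops a commented-out debug branch; the surrounding function builds a logit bias that blocks every token that is not an integer string. A hedged reconstruction from the lines shown (the numpy buffer is an assumption; the loop body is verbatim from the diff):

import numpy as np

def int_token_logit_bias(tokenizer, vocab_size):
    # Penalize every token that is neither a digit string, empty after
    # strip(), nor the EOS token, so sampling can only emit integers.
    logit_bias = np.zeros(vocab_size, dtype=np.float32)
    for t_id in range(vocab_size):
        ss = tokenizer.decode([t_id]).strip()
        if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
            logit_bias[t_id] = -1e5
    return logit_bias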
python/sglang/test/test_utils.py (view file @ c51020cf)

 """Common utilities for testing and benchmarking"""
 import numpy as np
 import requests

 from sglang.backend.openai import OpenAI
 ...
python/sglang/utils.py (view file @ c51020cf)

...
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):
     if torch.cuda.current_device() != gpu_id:
         print(
-            f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
             "which may cause useless memory allocation for torch CUDA context.",
         )
...
@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
             return requests.post(url, json=json, stream=True)
         headers = {
             "Content-Type": "application/json",
-            "Authentication": f"Bearer {auth_token}"
+            "Authentication": f"Bearer {auth_token}",
         }
         return requests.post(url, json=json, stream=True, headers=headers)
     else:
...
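Note: with the trailing-comma fix above, the streaming branch of http_request sends the Authentication header and returns a streaming requests.Response. A hedged usage sketch (the URL and token are placeholders):

resp = http_request(
    "http://localhost:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 16}},
    stream=True,
    auth_token="my-token",
)
for chunk in resp.iter_lines():
    print(chunk)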
test/lang/test_srt_backend.py (view file @ c51020cf)

 """
 python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
 """
 import json
 import unittest
 ...