sglang · Commit 23f05005 (unverified)
Authored Feb 06, 2024 by Lianmin Zheng; committed via GitHub on Feb 06, 2024
Parent: a7334aee

Format code & move functions (#155)

Showing 14 changed files with 94 additions and 54 deletions (+94, -54)
Changed files:

    python/sglang/lang/chat_template.py                 +1   -0
    python/sglang/srt/layers/logits_processor.py        +8   -2
    python/sglang/srt/layers/radix_attention.py         +1   -8
    python/sglang/srt/managers/io_struct.py             +1   -0
    python/sglang/srt/managers/router/model_rpc.py      +22  -13
    python/sglang/srt/managers/router/model_runner.py   +8   -4
    python/sglang/srt/models/qwen.py                    +19  -12
    python/sglang/srt/models/yivl.py                    +20  -7
    python/sglang/srt/server.py                         +5   -2
    python/sglang/srt/server_args.py                    +1   -1
    python/sglang/srt/utils.py                          +1   -1
    test/srt/test_httpserver_decode.py                  +2   -0
    test/srt/test_httpserver_decode_stream.py           +3   -1
    test/srt/test_openai_server.py                      +2   -3

For each file below, the hunk header is followed by the affected code as it reads after the commit.
python/sglang/lang/chat_template.py (+1, -0)

@@ -193,6 +193,7 @@ def match_chat_ml(model_path: str):
    if "qwen" in model_path and "chat" in model_path:
        return get_chat_template("chatml")


@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    model_path = model_path.lower()
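The matcher above relies on a registry pattern: each function decorated with register_chat_template_matching_function inspects a model path and, when it recognizes it, returns the matching chat template. Below is a minimal standalone sketch of that pattern; the registry, decorator body, and guess_chat_template helper are illustrative stand-ins, not sglang's actual internals.

    # Illustrative registry sketch; names are hypothetical, not sglang's internals.
    chat_template_matchers = []

    def register_chat_template_matching_function(func):
        # Record the matcher so it can be tried against a model path later.
        chat_template_matchers.append(func)
        return func

    @register_chat_template_matching_function
    def match_chat_ml(model_path: str):
        model_path = model_path.lower()
        if "qwen" in model_path and "chat" in model_path:
            return "chatml"

    def guess_chat_template(model_path: str):
        # Return the first template name any registered matcher claims.
        for matcher in chat_template_matchers:
            name = matcher(model_path)
            if name is not None:
                return name
        return None

    print(guess_chat_template("Qwen/Qwen-7B-Chat"))  # -> "chatml"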
python/sglang/srt/layers/logits_processor.py (+8, -2)

@@ -64,13 +64,19 @@ class LogitsProcessor(nn.Module):
                torch.arange(all_logprobs.shape[0], device="cuda"),
                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
            ]
            logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)

            start = input_metadata.extend_start_loc.clone()
            end = start + input_metadata.extend_seq_lens - 2
            start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
            end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
            sum_logp = (
                logprobs_cumsum[end]
                - logprobs_cumsum[start]
                + prefill_logprobs[start]
            )
            normalized_logprobs = sum_logp / (
                (input_metadata.extend_seq_lens - 1).clamp(min=1)
            )
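The hunk above computes a length-normalized prompt logprob for every request in a packed prefill batch without a Python loop: a prefix sum over the flat logprob vector lets each request's sum be read off as cumsum[end] - cumsum[start] + logprobs[start]. Because the logprob stored at position i corresponds to the token at position i+1 (note the input_ids[1:] shift), the last position of each request is excluded, hence end = start + extend_seq_lens - 2 and normalization by extend_seq_lens - 1. A self-contained numerical sketch with toy CPU tensors (the real code runs on CUDA inside LogitsProcessor):

    import torch

    # Toy example: two packed requests with 4 and 3 extended tokens each.
    prefill_logprobs = torch.tensor([-0.5, -1.0, -0.2, -0.3, -0.7, -0.1, -0.4])
    extend_start_loc = torch.tensor([0, 4])   # where each request starts in the packed batch
    extend_seq_lens = torch.tensor([4, 3])    # tokens per request

    logprobs_cumsum = torch.cumsum(prefill_logprobs, dim=0, dtype=torch.float32)

    start = extend_start_loc.clone()
    end = start + extend_seq_lens - 2
    start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)

    # Per-request sum over positions [start, end] without a loop:
    # cumsum[end] - cumsum[start] + logprobs[start] == logprobs[start] + ... + logprobs[end].
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + prefill_logprobs[start]
    normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)
    print(normalized_logprobs)  # tensor([-0.5667, -0.4000])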
python/sglang/srt/layers/radix_attention.py (+1, -8)

@@ -13,14 +13,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
class RadixAttention(nn.Module):
    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
        super().__init__()
        self.tp_q_head_num = num_heads
        self.tp_k_head_num = num_kv_heads
python/sglang/srt/managers/io_struct.py (+1, -0)

@@ -100,6 +100,7 @@ class BatchStrOut:
class FlushCacheReq:
    pass


@dataclass
class DetokenizeReqInput:
    input_ids: List[int]
python/sglang/srt/managers/router/model_rpc.py (+22, -13)

@@ -11,8 +11,8 @@ import rpyc
import torch
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
    BatchTokenIDOut,

@@ -391,8 +391,12 @@ class ModelRpcServer(rpyc.Service):
        logprobs = None
        if batch.extend_num_tokens != 0:
            # Forward
            logits, (
                prefill_logprobs,
                normalized_logprobs,
                last_logprobs,
            ) = self.model_runner.forward(
                batch, ForwardMode.EXTEND, batch.return_logprob
            )
            if prefill_logprobs is not None:
                logprobs = prefill_logprobs.cpu().tolist()

@@ -407,7 +411,9 @@ class ModelRpcServer(rpyc.Service):
        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        reqs = batch.reqs
        if last_logprobs is not None:
            last_logprobs = (
                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
            )

        # Check finish condition
        pt = 0

@@ -482,7 +488,9 @@ class ModelRpcServer(rpyc.Service):
        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
        reqs = batch.reqs
        if last_logprobs is not None:
            last_logprobs = last_logprobs[
                torch.arange(len(reqs)), next_token_ids
            ].tolist()

        # Check finish condition
        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):

@@ -620,15 +628,16 @@ class ModelRpcClient:
        self.step = async_wrap("step")


def _init_service(port):
    t = ThreadedServer(
        ModelRpcServer(),
        port=port,
        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
    )
    t.start()


def start_model_process(port):
    proc = multiprocessing.Process(target=_init_service, args=(port,))
    proc.start()
    time.sleep(1)
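The last hunk is the "move functions" part of the commit: _init_service is hoisted out of start_model_process to module level, so start_model_process only spawns the child process that serves ModelRpcServer over rpyc. A runnable sketch of the same pattern, with a hypothetical EchoService standing in for ModelRpcServer:

    import multiprocessing
    import time

    import rpyc
    from rpyc.utils.server import ThreadedServer


    class EchoService(rpyc.Service):
        # Stand-in for ModelRpcServer; any rpyc.Service works with this pattern.
        def exposed_echo(self, x):
            return x


    def _init_service(port):
        # Runs inside the child process and blocks, serving RPC requests.
        t = ThreadedServer(
            EchoService(),
            port=port,
            protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
        )
        t.start()


    def start_model_process(port):
        proc = multiprocessing.Process(target=_init_service, args=(port,))
        proc.start()
        time.sleep(1)  # give the server a moment to bind the port
        return proc


    if __name__ == "__main__":
        proc = start_model_process(18861)
        conn = rpyc.connect("localhost", 18861, config={"allow_pickle": True})
        print(conn.root.echo("hello"))
        proc.terminate()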
python/sglang/srt/managers/router/model_runner.py (+8, -4)

@@ -17,8 +17,8 @@ from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel

import sglang

QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

logger = logging.getLogger("model_runner")

@@ -283,9 +283,13 @@ class ModelRunner:
            self.model_config.hf_config, "quantization_config", None
        )
        if hf_quant_config is not None:
            quant_config_class = QUANTIONCONFIG_MAPPING.get(
                hf_quant_config["quant_method"]
            )
            if quant_config_class is None:
                raise ValueError(
                    f"Unsupported quantization method: {hf_quant_config['quant_method']}"
                )
            quant_config = quant_config_class.from_config(hf_quant_config)
            logger.info(f"quant_config: {quant_config}")
            linear_method = quant_config.get_linear_method()
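The second hunk dispatches on hf_quant_config["quant_method"] to pick a quantization config class. A minimal sketch of that dispatch follows; the stub AWQConfig/GPTQConfig classes are placeholders for the real classes imported from vLLM:

    # Sketch of the quant-method dispatch; the config classes here are placeholders.
    class AWQConfig:
        @classmethod
        def from_config(cls, cfg):
            return cls()

    class GPTQConfig:
        @classmethod
        def from_config(cls, cfg):
            return cls()

    QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig}

    def build_quant_config(hf_quant_config):
        # hf_quant_config mirrors the dict found on hf_config.quantization_config.
        quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_config["quant_method"])
        if quant_config_class is None:
            raise ValueError(
                f"Unsupported quantization method: {hf_quant_config['quant_method']}"
            )
        return quant_config_class.from_config(hf_quant_config)

    print(build_quant_config({"quant_method": "awq"}))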
python/sglang/srt/models/qwen.py (+19, -12)

@@ -42,14 +42,14 @@ class QWenMLP(nn.Module):
            2 * [intermediate_size],
            bias=False,
            gather_output=False,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(

@@ -74,7 +74,7 @@ class QWenAttention(nn.Module):
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size

@@ -86,18 +86,18 @@ class QWenAttention(nn.Module):
        # pylint: disable=invalid-name
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,

@@ -143,12 +143,16 @@ class QWenBlock(nn.Module):
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            layer_id=layer_id,
            linear_method=linear_method,
        )
        self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(
            config.hidden_size,
            config.intermediate_size // 2,
            linear_method=linear_method,
        )

    def forward(
        self,

@@ -186,7 +190,10 @@ class QWenModel(nn.Module):
            config.hidden_size,
        )
        self.h = nn.ModuleList(
            [
                QWenBlock(config, i, linear_method=linear_method)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
python/sglang/srt/models/yivl.py (+20, -7)

@@ -4,14 +4,17 @@ from typing import List, Optional
import torch
import torch.nn as nn
from sglang.srt.models.llava import (
    LlavaLlamaForCausalLM,
    clip_vision_embed_forward,
    monkey_path_clip_vision_embed_forward,
)
from transformers import CLIPVisionModel, LlavaConfig
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)


class YiVLForCausalLM(LlavaLlamaForCausalLM):
    def __init__(self, *args, **kwargs):

@@ -19,7 +22,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
        super().__init__(self.config)

        self.multi_modal_projector = YiVLMultiModalProjector(self.config)
        self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
            "./", ""
        )  # Everything after "./"

    def load_weights(
        self,

@@ -30,7 +35,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
    ):
        # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
        self.vision_tower = CLIPVisionModel.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16,
            subfolder=self.vision_tower_subfolder,
        ).cuda()
        self.vision_tower.eval()

@@ -80,14 +87,19 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
        monkey_path_clip_vision_embed_forward()


class YiVLMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaConfig):
        super().__init__()

        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size, config.text_config.hidden_size
        )
        self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)

    def forward(self, image_features):

@@ -98,4 +110,5 @@ class YiVLMultiModalProjector(nn.Module):
        hidden_states = self.ln_2(hidden_states)
        return hidden_states


EntryClass = YiVLForCausalLM
python/sglang/srt/server.py (+5, -2)

@@ -63,6 +63,7 @@ chat_template_name = None
# FIXME: Remove this once we drop support for pydantic 1.x
IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


def jsonify_pydantic_model(obj: BaseModel):
    if IS_PYDANTIC_1:
        return obj.json(ensure_ascii=False)

@@ -165,7 +166,7 @@ async def v1_completions(raw_request: Request):
            prompt_tokens = content["meta_info"]["prompt_tokens"]
            completion_tokens = content["meta_info"]["completion_tokens"]

            if not stream_buffer:  # The first chunk
                if request.echo:
                    # Prepend prompt in response text.
                    text = request.prompt + text

@@ -219,7 +220,9 @@ async def v1_completions(raw_request: Request):
                token_logprob_pos = prompt_tokens

            logprobs = (
                await make_openai_style_logprobs(
                    ret["meta_info"]["token_logprob"][token_logprob_pos:]
                )
                if request.logprobs is not None
                else None
            )
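jsonify_pydantic_model branches on IS_PYDANTIC_1; only the pydantic 1.x branch appears in this hunk. A sketch of a version-compatible helper, where the pydantic 2 branch (model_dump_json) is an assumption not taken from the diff:

    import pydantic
    from pydantic import BaseModel

    # True when the installed pydantic is the 1.x series.
    IS_PYDANTIC_1 = int(pydantic.VERSION.split(".")[0]) == 1


    def jsonify_pydantic_model(obj: BaseModel) -> str:
        if IS_PYDANTIC_1:
            return obj.json(ensure_ascii=False)
        # Assumed v2 branch (not shown in the hunk): pydantic 2 renamed .json() to .model_dump_json().
        return obj.model_dump_json()


    class Message(BaseModel):
        role: str
        content: str


    print(jsonify_pydantic_model(Message(role="user", content="héllo")))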
python/sglang/srt/server_args.py (+1, -1)

@@ -114,7 +114,7 @@ class ServerArgs:
            "--max-prefill-num-token",
            type=int,
            default=ServerArgs.max_prefill_num_token,
            help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
        )
        parser.add_argument(
            "--tp-size",
python/sglang/srt/utils.py (+1, -1)

@@ -259,4 +259,4 @@ def load_image(image_file):
    else:
        image = Image.open(BytesIO(base64.b64decode(image_file)))

    return image
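The else branch above handles images passed as base64-encoded strings. A small round-trip example of that branch using Pillow:

    import base64
    from io import BytesIO

    from PIL import Image

    # Encode a tiny in-memory image to base64, then decode it the same way load_image does.
    buf = BytesIO()
    Image.new("RGB", (8, 8), color=(255, 0, 0)).save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    image = Image.open(BytesIO(base64.b64decode(image_b64)))
    print(image.size, image.mode)  # (8, 8) RGB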
test/srt/test_httpserver_decode.py (+2, -0)

@@ -12,6 +12,7 @@ import argparse
import requests


def test_decode(url, return_logprob):
    response = requests.post(
        url + "/generate",

@@ -27,6 +28,7 @@ def test_decode(url, return_logprob):
    )
    print(response.json())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
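The test posts to the server's /generate endpoint, but the request body is elided in this hunk. The sketch below fills it with assumed fields (text, sampling_params, return_logprob) and an illustrative port; none of these values are taken from the diff:

    import requests


    def test_decode(url, return_logprob):
        # Hypothetical payload; the actual fields used by the test are elided above.
        response = requests.post(
            url + "/generate",
            json={
                "text": "The capital of France is",
                "sampling_params": {"temperature": 0, "max_new_tokens": 16},
                "return_logprob": return_logprob,
            },
        )
        print(response.json())


    if __name__ == "__main__":
        test_decode("http://127.0.0.1:30000", return_logprob=True)  # port is illustrative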
test/srt/test_httpserver_decode_stream.py (+3, -1)

@@ -12,6 +12,7 @@ import json
import requests


def test_decode_stream(url, return_logprob):
    response = requests.post(
        url + "/generate",

@@ -39,7 +40,7 @@ def test_decode_stream(url, return_logprob):
            assert data["meta_info"]["prompt_logprob"] is not None
            assert data["meta_info"]["token_logprob"] is not None
            assert data["meta_info"]["normalized_prompt_logprob"] is not None

            if prev == 0:  # Skip prompt logprobs
                prev = data["meta_info"]["prompt_tokens"]
            for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
                print(f"{token_txt}\t{logprob}", flush=True)

@@ -50,6 +51,7 @@ def test_decode_stream(url, return_logprob):
            prev = len(output)

    print("")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
test/srt/test_openai_server.py (+2, -3)

@@ -64,9 +64,8 @@ def test_completion_stream(args, echo, logprobs):
        first = False
        if logprobs:
            print(
                f"{r.choices[0].text:12s}\t"
                f"{r.choices[0].logprobs.token_logprobs}",
                flush=True,
            )
        else:
            print(r.choices[0].text, end="", flush=True)