sglang · Commit 3bc01ac1

[Minor] improve code style

Authored Jun 03, 2024 by Lianmin Zheng
Parent: 9f009261

Showing 4 changed files with 67 additions and 17 deletions:

  benchmark/latency_throughput/bench_throughput.py   +4   -4
  benchmark/latency_throughput/test_latency.py       +3   -3
  python/sglang/srt/hf_transformers_utils.py         +54  -4
  python/sglang/test/test_utils.py                    +6   -6
benchmark/latency_throughput/bench_throughput.py  (view file @ 3bc01ac1)

@@ -149,12 +149,12 @@ async def send_request(
             "inputs": prompt,
             "parameters": params,
         }
-    elif backend == "xinfer":
+    elif backend == "ginfer":
         pass
     else:
         raise ValueError(f"Unknown backend: {backend}")

-    if backend != "xinfer":
+    if backend != "ginfer":
         timeout = aiohttp.ClientTimeout(total=3 * 3600)
         async with aiohttp.ClientSession(timeout=timeout) as session:
             while True:

@@ -172,7 +172,7 @@ async def send_request(
             print(output)
     else:
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc

         api_url = api_url.replace("http://", "").replace("/generate", "")
         sampler_channel = grpc.aio.insecure_channel(api_url)

@@ -283,7 +283,7 @@ if __name__ == "__main__":
         "--backend",
         type=str,
         default="srt",
-        choices=["vllm", "tgi", "srt", "lightllm", "xinfer"],
+        choices=["vllm", "tgi", "srt", "lightllm", "ginfer"],
     )
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=30000)
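For orientation: every backend except ginfer is exercised over HTTP, inside an aiohttp session with the long client timeout shown above. A minimal sketch of that request path follows; the endpoint, payload keys, and helper name are illustrative, not taken from the benchmark script.

import asyncio

import aiohttp


async def post_generate(api_url: str, payload: dict) -> dict:
    # Generous timeout, mirroring the 3 * 3600 seconds used in bench_throughput.py.
    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.post(api_url, json=payload) as resp:
            return await resp.json()


# Example (assumes a server is listening on the default srt port):
# asyncio.run(post_generate("http://localhost:30000/generate",
#                           {"inputs": "Hello", "parameters": {"max_new_tokens": 16}}))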
benchmark/latency_throughput/test_latency.py  (view file @ 3bc01ac1)

@@ -18,7 +18,7 @@ if __name__ == "__main__":
         args.port = 21000
     elif args.backend == "lightllm":
         args.port = 22000
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         args.port = 9988
     else:
         raise ValueError(f"Invalid backend: {args.backend}")

@@ -60,9 +60,9 @@ if __name__ == "__main__":
                 "max_tokens": max_new_tokens,
             },
         )
-    elif args.backend == "xinfer":
+    elif args.backend == "ginfer":
         import grpc
-        from xlm.proto import sampler_pb2, sampler_pb2_grpc
+        from ginfer import sampler_pb2, sampler_pb2_grpc

         sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
         sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
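Note that the gRPC path strips the scheme from the URL before opening the channel, because grpc.insecure_channel expects a bare "host:port" target rather than an HTTP URL. A hypothetical helper (not part of the repo) showing that normalization:

import grpc


def open_sampler_channel(url: str) -> grpc.Channel:
    # gRPC targets look like "localhost:9988", so drop the scheme and any route suffix.
    target = url.replace("http://", "").replace("/generate", "")
    return grpc.insecure_channel(target)


channel = open_sampler_channel("http://localhost:9988")
# The stub is then built from the generated ginfer modules:
# sampler = sampler_pb2_grpc.SamplerStub(channel)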
python/sglang/srt/hf_transformers_utils.py  (view file @ 3bc01ac1)

@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal

 from huggingface_hub import snapshot_download
 from transformers import (

@@ -177,10 +178,57 @@ def get_processor(

 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
-        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+        import tiktoken

-        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=disallowed_special,
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)

+        # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_token
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab

     def encode(self, x, add_special_tokens=False):

@@ -190,6 +238,8 @@ class TiktokenTokenizer:
         return self.tokenizer.decode(x)

     def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)

     def convert_ids_to_tokens(self, index):
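The rewritten TiktokenTokenizer constructs a tiktoken.Encoding directly from a JSON token dump and then monkey-patches encode so that a set of "default allowed" special tokens never trips tiktoken's special-token check. Below is a self-contained sketch of the same pattern; the byte-level vocabulary and single <|eos|> special token are made up for the example and stand in for the real contents of tok_dict.

import functools
from typing import AbstractSet, Collection, Literal, Union

import tiktoken

# Toy vocabulary: one rank per byte, plus a single special token.
mergeable_ranks = {bytes([i]): i for i in range(256)}
special_tokens = {"<|eos|>": 256}

enc = tiktoken.Encoding(
    name="demo",
    pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
    mergeable_ranks=mergeable_ranks,
    special_tokens=special_tokens,
)

# Specials that should always be accepted, mirroring _default_allowed_special.
enc._default_allowed_special = {"<|eos|>"}


def encode_patched(
    self,
    text: str,
    *,
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
    disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
    # Merge in the default-allowed specials (non-mutating, to avoid touching the shared default).
    if isinstance(allowed_special, set):
        allowed_special = allowed_special | self._default_allowed_special
    return tiktoken.Encoding.encode(
        self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
    )


# Bind the patched method to this instance, as the commit does with functools.partial.
enc.encode = functools.partial(encode_patched, enc)

print(enc.encode("hi<|eos|>"))         # the special token is encoded instead of raising
print(enc._special_tokens["<|eos|>"])  # 256, the id used as eos_token_id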
python/sglang/test/test_utils.py  (view file @ 3bc01ac1)

@@ -88,9 +88,9 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
     return pred


-def call_generate_xinfer(prompt, temperature, max_tokens, stop=None, url=None):
+def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
     import grpc
-    from xlm.proto import sampler_pb2, sampler_pb2_grpc
+    from ginfer import sampler_pb2, sampler_pb2_grpc

     sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
     sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)

@@ -255,7 +255,7 @@ def add_common_other_args_and_parse(parser):
             "vllm",
             "outlines",
             "lightllm",
-            "xinfer",
+            "ginfer",
             "guidance",
             "lmql",
             "srt-raw",

@@ -276,7 +276,7 @@ def add_common_other_args_and_parse(parser):
         "lightllm": 22000,
         "lmql": 23000,
         "srt-raw": 30000,
-        "xinfer": 9988,
+        "ginfer": 9988,
     }
     args.port = default_port.get(args.backend, None)
     return args

@@ -312,8 +312,8 @@ def _get_call_generate(args):
         return partial(call_generate_vllm, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "srt-raw":
         return partial(call_generate_srt_raw, url=f"{args.host}:{args.port}/generate")
-    elif args.backend == "xinfer":
-        return partial(call_generate_xinfer, url=f"{args.host}:{args.port}")
+    elif args.backend == "ginfer":
+        return partial(call_generate_ginfer, url=f"{args.host}:{args.port}")
     elif args.backend == "outlines":
         return partial(call_generate_outlines, url=f"{args.host}:{args.port}/generate")
     elif args.backend == "guidance":
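_get_call_generate keeps one calling convention across backends by pre-binding each call_generate_* function to its URL with functools.partial: HTTP backends get a /generate route, while the gRPC ginfer backend gets a bare host:port. A small sketch of that dispatch pattern; the stub function and helper name are placeholders rather than the repo's code.

from functools import partial


def call_generate_stub(prompt, temperature, max_tokens, stop=None, url=None):
    # Placeholder standing in for call_generate_vllm / call_generate_ginfer / ...
    return f"would send {prompt!r} to {url}"


def get_call_generate(backend: str, host: str, port: int):
    if backend in ("vllm", "srt-raw", "outlines"):
        return partial(call_generate_stub, url=f"{host}:{port}/generate")
    elif backend == "ginfer":
        # gRPC backends take a bare host:port, not an HTTP /generate route.
        return partial(call_generate_stub, url=f"{host}:{port}")
    raise ValueError(f"Invalid backend: {backend}")


call = get_call_generate("ginfer", "localhost", 9988)
print(call("Hello", temperature=0.0, max_tokens=16))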