sglang / Commits / 1cb4da5c

[Fix] the issue of random order when input is a list (#1199)

Unverified commit 1cb4da5c, authored Aug 24, 2024 by Ying Sheng and committed by GitHub on Aug 24, 2024.
Parent: e61d13ac

Showing 4 changed files with 23 additions and 20 deletions (+23 / -20):
python/sglang/srt/managers/tokenizer_manager.py  (+7 / -7)
python/sglang/srt/server.py  (+2 / -2)
python/sglang/test/runners.py  (+1 / -1)
test/srt/models/test_embedding_models.py  (+13 / -10)
python/sglang/srt/managers/tokenizer_manager.py  (+7 / -7)

@@ -437,13 +437,13 @@ class TokenizerManager:
             is_stream = hasattr(obj, "stream") and obj.stream
 
             tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
-            output_list = []
+            output_list = [None] * len(tasks)
 
             while tasks:
                 done, _ = await asyncio.wait(
                     tasks, return_when=asyncio.FIRST_COMPLETED
                 )
                 for task in done:
-                    gen_index = tasks.index(task)
+                    cur_index = tasks.index(task)
                     try:
                         result = task.result()
@@ -451,14 +451,14 @@ class TokenizerManager:
                         if is_stream:
                             yield result
                         else:
-                            output_list.append(result)
+                            output_list[result["index"]] = result
 
-                        tasks[gen_index] = asyncio.create_task(
-                            generators[gen_index].__anext__()
+                        tasks[cur_index] = asyncio.create_task(
+                            generators[cur_index].__anext__()
                         )
                     except StopAsyncIteration:
-                        del generators[gen_index]
-                        del tasks[gen_index]
+                        del generators[cur_index]
+                        del tasks[cur_index]
 
             if not is_stream:
                 yield output_list
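The tokenizer_manager.py change is the heart of the fix. In the non-streaming path, batch results used to be appended to output_list in completion order, so the outputs for a list input could come back shuffled. The new code pre-allocates one slot per input and writes each result into the slot named by its "index" field. Below is a simplified, self-contained sketch of that idea; it uses plain coroutines with made-up names (fake_generate, gather_in_input_order) instead of sglang's per-request async generators, so it illustrates the pattern rather than the actual sglang code.

```python
import asyncio
import random


async def fake_generate(index: int) -> dict:
    """Illustrative stand-in for one per-prompt request; finishes after a random delay."""
    await asyncio.sleep(random.uniform(0.01, 0.1))
    return {"index": index, "text": f"output for prompt {index}"}


async def gather_in_input_order(num_prompts: int) -> list:
    tasks = [asyncio.create_task(fake_generate(i)) for i in range(num_prompts)]

    # The fix: pre-allocate one slot per input instead of appending results
    # as they happen to complete.
    output_list = [None] * len(tasks)

    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            result = task.result()
            # Place each result by the index it carries, so the output order
            # matches the input order regardless of completion order.
            output_list[result["index"]] = result

    return output_list


if __name__ == "__main__":
    outputs = asyncio.run(gather_in_input_order(8))
    assert [r["index"] for r in outputs] == list(range(8))
    print([r["text"] for r in outputs])
```

The actual diff keeps sglang's resubmission loop over async generators; only the bookkeeping of output_list (and the index variable name) changes.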
python/sglang/srt/server.py  (+2 / -2)

@@ -591,7 +591,7 @@ class Runtime:
     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
@@ -612,7 +612,7 @@ class Runtime:
     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,
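The server.py hunks only widen the type hints so that Runtime.generate and Runtime.encode formally accept a list of prompts as well as a single string. Here is a hypothetical usage sketch under that signature; the model path, sampling parameters, and the shutdown() call follow common sglang Runtime conventions but are assumptions, not part of this diff.

```python
from sglang.srt.server import Runtime

# Hypothetical model; any model supported by sglang would do.
runtime = Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")

# A single prompt (str) still works as before.
single = runtime.generate(
    "The capital of France is", sampling_params={"max_new_tokens": 8}
)

# After this commit, a list of prompts also matches the annotated signature, and
# the tokenizer_manager fix makes the i-th output correspond to the i-th prompt.
batch = runtime.generate(
    ["The capital of France is", "The capital of Japan is"],
    sampling_params={"max_new_tokens": 8},
)

print(single)
print(batch)

runtime.shutdown()
```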
python/sglang/test/runners.py  (+1 / -1)

@@ -28,10 +28,10 @@ from sglang.srt.server import Runtime
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
-    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
 ]
 
 dirpath = os.path.dirname(__file__)
test/srt/models/test_embedding_models.py  (+13 / -10)

@@ -20,7 +20,7 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 from sglang.test.test_utils import get_similarities
 
-MODELS = [("intfloat/e5-mistral-7b-instruct", 1)]
+MODELS = [("intfloat/e5-mistral-7b-instruct", 1, 0.2)]
 TORCH_DTYPES = [torch.float16]
@@ -32,6 +32,7 @@ class TestEmbeddingModels(unittest.TestCase):
         model_path,
         tp_size,
         torch_dtype,
+        long_context_tolerance,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=False
@@ -52,20 +53,22 @@ class TestEmbeddingModels(unittest.TestCase):
                 hf_logits = torch.Tensor(hf_outputs.embed_logits[i])
                 srt_logits = torch.Tensor(srt_outputs.embed_logits[i])
 
-                similarities = torch.tensor(get_similarities(hf_logits, srt_logits))
-                print("max similarity diff", torch.max(abs(similarities - 1)))
+                similarity = torch.tensor(get_similarities(hf_logits, srt_logits))
+                print("similarity diff", abs(similarity - 1))
 
-                if hf_logits.shape[0] <= 100:
-                    tolerance = 1e-2
-                    assert torch.all(
-                        abs(similarities - 1) < tolerance
-                    ), "embeddings are not all close"
+                if len(prompts[i]) <= 1000:
+                    tolerance = 1e-5
+                else:
+                    tolerance = long_context_tolerance
+                assert torch.all(
+                    abs(similarity - 1) < tolerance
+                ), "embeddings are not all close"
 
     def test_prefill_logits(self):
-        for model, tp_size in MODELS:
+        for model, tp_size, long_context_tolerance in MODELS:
             for torch_dtype in TORCH_DTYPES:
                 self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, long_context_tolerance
                 )
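For context on what the embedding test asserts: get_similarities (from sglang.test.test_utils) compares the HF and SRT embeddings of the same prompt, and the assertion requires that value to sit within a tolerance of 1, where the tolerance is now 1e-5 for short prompts and the per-model long_context_tolerance (0.2 for e5-mistral) for long ones. Below is a standalone sketch of the same style of check, assuming a cosine-similarity comparison; the cosine_similarity helper is an illustrative stand-in, not the sglang implementation.

```python
import torch
import torch.nn.functional as F


def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in for the comparison performed by get_similarities."""
    return F.cosine_similarity(a.flatten(), b.flatten(), dim=0)


# Two nearly identical embedding vectors, standing in for HF and SRT outputs.
hf_embedding = torch.randn(4096)
srt_embedding = hf_embedding + 1e-4 * torch.randn(4096)

similarity = cosine_similarity(hf_embedding, srt_embedding)
print("similarity diff", abs(similarity - 1).item())

tolerance = 1e-5  # the short-prompt tolerance used by the updated test
assert torch.all(abs(similarity - 1) < tolerance), "embeddings are not all close"
```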