Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
eff468dd
Unverified
Commit
eff468dd
authored
Nov 12, 2024
by
Xiaoyu Zhang
Committed by
GitHub
Nov 12, 2024
Browse files
fix bug in test_embedding_models where the prompt length is too long (#2015)
parent
a1bd7190
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
2 deletions
+23
-2
test/srt/models/test_embedding_models.py
test/srt/models/test_embedding_models.py
+23
-2
No files found.
test/srt/models/test_embedding_models.py
View file @
eff468dd
...
...
@@ -17,6 +17,7 @@ import multiprocessing as mp
import
unittest
import
torch
from
transformers
import
AutoConfig
,
AutoTokenizer
from
sglang.test.runners
import
DEFAULT_PROMPTS
,
HFRunner
,
SRTRunner
from
sglang.test.test_utils
import
get_similarities
...
...
@@ -34,6 +35,24 @@ class TestEmbeddingModels(unittest.TestCase):
def
setUpClass
(
cls
):
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
def _truncate_prompts(self, prompts, model_path):
    """Return *prompts* with any over-long entry clipped to the model's context size.

    Each prompt is tokenized without truncation; if its token count exceeds the
    model's ``max_position_embeddings`` (default 2048 when the config lacks the
    attribute), it is decoded back to text from the first ``max_length - 1``
    tokens. Prompts that already fit are passed through unchanged.
    """
    model_config = AutoConfig.from_pretrained(model_path)
    limit = getattr(model_config, "max_position_embeddings", 2048)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    clipped = []
    for text in prompts:
        token_ids = tokenizer(text, return_tensors="pt", truncation=False).input_ids[0]
        if len(token_ids) > limit:
            # Keep limit - 1 tokens to leave headroom for special tokens the
            # runner may add; drop special tokens so the decoded text is clean.
            clipped.append(
                tokenizer.decode(token_ids[: limit - 1], skip_special_tokens=True)
            )
        else:
            clipped.append(text)
    return clipped
def
assert_close_prefill_logits
(
self
,
prompts
,
...
...
@@ -42,12 +61,14 @@ class TestEmbeddingModels(unittest.TestCase):
torch_dtype
,
prefill_tolerance
,
)
->
None
:
truncated_prompts
=
self
.
_truncate_prompts
(
prompts
,
model_path
)
with
HFRunner
(
model_path
,
torch_dtype
=
torch_dtype
,
model_type
=
"embedding"
,
)
as
hf_runner
:
hf_outputs
=
hf_runner
.
forward
(
prompts
)
hf_outputs
=
hf_runner
.
forward
(
truncated_
prompts
)
with
SRTRunner
(
model_path
,
...
...
@@ -55,7 +76,7 @@ class TestEmbeddingModels(unittest.TestCase):
torch_dtype
=
torch_dtype
,
model_type
=
"embedding"
,
)
as
srt_runner
:
srt_outputs
=
srt_runner
.
forward
(
prompts
)
srt_outputs
=
srt_runner
.
forward
(
truncated_
prompts
)
for
i
in
range
(
len
(
prompts
)):
hf_logits
=
torch
.
Tensor
(
hf_outputs
.
embed_logits
[
i
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment