Commit 689ff588 (unverified)
[CI] Return output logprobs in unit test (#1361)
Authored Sep 09, 2024 by Ying Sheng; committed by GitHub on Sep 09, 2024
Parent: a7c47e0f
Showing 2 changed files with 73 additions and 21 deletions:
  python/sglang/test/runners.py              +42 -15
  test/srt/models/test_generation_models.py  +31  -6
python/sglang/test/runners.py
@@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype):
         raise NotImplementedError()
 
 
+def get_top_logprobs(logits, k):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
+    return logprobs
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
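
Note: the new get_top_logprobs helper folds the log-softmax + top-k pattern that was previously inlined in HFRunner into one reusable function. A minimal, self-contained sketch of how it behaves on a toy logits tensor (the tensor values and k here are illustrative only):

import torch
import torch.nn.functional as F

def get_top_logprobs(logits, k):
    # Same as the helper added above: convert raw logits to log-probabilities,
    # then keep only the k largest per position.
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
    return logprobs

# Toy example: 3 token positions over a 5-word vocabulary.
logits = torch.randn(3, 5)
top = get_top_logprobs(logits, k=2)
print(top.shape)  # torch.Size([3, 2]) -- top-2 logprobs per position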
@@ -108,7 +114,8 @@ class HFRunner:
             if prompts is not None:
                 if self.is_generation:
                     output_strs = []
-                    prefill_logprobs = []
+                    top_input_logprobs = []
+                    top_output_logprobs = []
                     for p in prompts:
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
@@ -117,32 +124,43 @@ class HFRunner:
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
-                        output_ids = self.model.generate(
-                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        outputs = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            temperature=None,
+                            top_p=None,
+                            max_new_tokens=max_new_tokens,
+                            return_dict_in_generate=True,
+                            output_scores=True,
                         )
                         output_strs.append(
-                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
+                        # outputs.scores: (num_token, 1, vocab_size)
+                        top_output_logprobs.append(
+                            [
+                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                                for logits in outputs.scores
+                            ]
+                        )
+                        del outputs
 
-                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(
-                            logits, dim=-1, dtype=torch.float32
-                        )
-                        logprobs, top_indices = torch.topk(
-                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
-                        )
-                        # print("index", top_indices)
-                        prefill_logprobs.append(logprobs.tolist())
-                        del logits
-                        del logprobs
+                        input_logits = self.model.forward(input_ids).logits[0]
+                        top_input_logprobs.append(
+                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
+                        )
+                        del input_logits
 
                     out_queue.put(
                         ModelOutput(
-                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                            output_strs=output_strs,
+                            top_input_logprobs=top_input_logprobs,
+                            top_output_logprobs=top_output_logprobs,
                         )
                     )
 
                 else:
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
     def forward(
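
Note: the rewritten HFRunner call uses the standard Hugging Face transformers generate API. With return_dict_in_generate=True and output_scores=True, generate returns an output object whose first field is the generated token ids (hence outputs[0][0] above) and whose .scores is a tuple with one (batch_size, vocab_size) logits tensor per generated token. A minimal sketch of that behaviour, assuming a small causal LM (the model name and prompt are illustrative, not from this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")           # illustrative model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tok.encode("Hello, my name is", return_tensors="pt")
outputs = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=4,
    return_dict_in_generate=True,
    output_scores=True,
)

# outputs.sequences (== outputs[0]): prompt + generated ids, shape (1, seq_len)
print(tok.decode(outputs.sequences[0][len(input_ids[0]):]))

# outputs.scores: tuple of length max_new_tokens, each (batch_size, vocab_size);
# converting each step's logits to top-k logprobs mirrors the loop in the hunk above.
print(len(outputs.scores), outputs.scores[0].shape)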
@@ -194,6 +212,7 @@ class SRTRunner:
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
+            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for prompt in prompts:
                 response = self.runtime.generate(
@@ -219,9 +238,17 @@ class SRTRunner:
                         ]
                     ]
                 )
+                top_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["output_top_logprobs"]
+                    ]
+                )
 
             return ModelOutput(
-                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+                output_strs=output_strs,
+                top_input_logprobs=top_input_logprobs,
+                top_output_logprobs=top_output_logprobs,
             )
         else:
             response = self.runtime.encode(prompts)
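
Note: on the SGLang side, the new code reads response["meta_info"]["output_top_logprobs"]. From the comprehension in the hunk above, the assumed shape is one entry per generated token, each entry a list of tuples whose first element is the logprob (any remaining tuple fields, such as a token id, are not used here). A sketch over a hand-mocked response illustrating only the extraction, not the real runtime call:

NUM_TOP_LOGPROBS = 5  # illustrative; use the constant defined in runners.py

# Mocked response; the structure is inferred from the comprehension above, with
# each generated token carrying a list of (logprob, token_id, ...) tuples.
response = {
    "meta_info": {
        "output_top_logprobs": [
            [(-0.01, 11), (-4.20, 42), (-5.10, 7)],   # token 0
            [(-0.30, 99), (-1.70, 3), (-6.00, 15)],   # token 1
        ]
    }
}

top_output_logprobs = [
    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
    for x in response["meta_info"]["output_top_logprobs"]
]
print(top_output_logprobs)  # [[-0.01, -4.2, -5.1], [-0.3, -1.7, -6.0]]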
test/srt/models/test_generation_models.py
@@ -21,9 +21,9 @@ import torch
 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
-    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
-    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
 ]
 TORCH_DTYPES = [torch.float16]
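
Note: each MODELS entry gains one new value (4e-2, 5e-2, 4e-2). Matching the per-field unpacking further down in this file, the assumed field order after this change is sketched below, using the Llama entry from above as the example:

(
    model_path,
    tp_size,
    long_context_tolerance,
    prefill_tolerance,
    output_tolerance,  # new column introduced by this commit
    rouge_threshold,
) = ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1)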
@@ -70,6 +70,7 @@ class TestGenerationModels(unittest.TestCase):
         torch_dtype,
         max_new_tokens,
         prefill_tolerance,
+        output_tolerance,
         rouge_threshold,
         long_context_tolerance,
     ) -> None:
@@ -89,15 +90,37 @@ class TestGenerationModels(unittest.TestCase):
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
+            # input logprobs comparison
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
-            if hf_logprobs.shape[0] <= 100:
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                 ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
 
+            # output logprobs comparison
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+            # print(
+            #     "output logprobs diff",
+            #     [
+            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
+            #         for j in range(max_new_tokens)
+            #     ],
+            # )
+            print(
+                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(
+                    abs(hf_logprobs - srt_logprobs) < output_tolerance
+                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts} ... output_tolerance={output_tolerance}"
+
+        # output strings comparison
         print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
         print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
         rouge_l_scores = calculate_rouge_l(
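
Note: the new output-logprob check reuses the same element-wise tolerance pattern as the prefill check. A toy sketch of the comparison (tensor values and the tolerance are made up for illustration):

import torch

output_tolerance = 4e-2  # illustrative value, mirrors the MODELS table above

hf_logprobs = torch.tensor([[-0.10, -2.31], [-0.05, -3.02]])
srt_logprobs = torch.tensor([[-0.11, -2.30], [-0.06, -3.00]])

max_diff = torch.max(abs(hf_logprobs - srt_logprobs))
print("output logprobs max_diff", max_diff)

# Same style of assertion as the test: every element must be within tolerance.
assert torch.all(abs(hf_logprobs - srt_logprobs) < output_tolerance)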
@@ -114,6 +137,7 @@ class TestGenerationModels(unittest.TestCase):
             tp_size,
             long_context_tolerance,
             prefill_tolerance,
+            output_tolerance,
             rouge_threshold,
         ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
@@ -125,6 +149,7 @@ class TestGenerationModels(unittest.TestCase):
                     torch_dtype,
                     max_new_tokens,
                     prefill_tolerance=prefill_tolerance,
+                    output_tolerance=output_tolerance,
                     rouge_threshold=rouge_threshold,
                     long_context_tolerance=long_context_tolerance,
                 )