sglang · Commit 32f61443 (Unverified)
Authored Aug 11, 2024 by Ying Sheng; committed by GitHub on Aug 12, 2024
Parent: fb1f28cb

fix: Fix returned prefill logits and add output str test (#1046)

Showing 3 changed files with 33 additions and 14 deletions (+33 -14):

python/sglang/srt/layers/logits_processor.py    +5  -0
python/sglang/test/runners.py                   +13 -8
test/srt/models/test_generation_models.py       +15 -6
python/sglang/srt/layers/logits_processor.py

@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
             all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
 
+            if hasattr(self.config, "final_logit_softcapping"):
+                all_logits /= self.config.final_logit_softcapping
+                all_logits = torch.tanh(all_logits)
+                all_logits *= self.config.final_logit_softcapping
+
             all_logprobs = all_logits
             del all_logits, hidden_states
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
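The added block is the final logit soft-capping used by Gemma-2-style configs: logits are divided by the cap, squashed with tanh, and rescaled, so their magnitude stays bounded by final_logit_softcapping before the log-softmax. A minimal standalone sketch of the transform, assuming a hypothetical helper name softcap and an illustrative cap value (not taken from the diff):

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # cap * tanh(logits / cap): close to identity for small logits, bounded by +/- cap.
    return cap * torch.tanh(logits / cap)

logits = torch.tensor([[2.0, 55.0, -120.0]])
capped = softcap(logits, cap=30.0)  # 30.0 is an illustrative cap value
print(capped)  # all values now lie strictly inside (-30, 30)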
python/sglang/test/runners.py

@@ -26,9 +26,11 @@ from sglang.srt.server import Runtime
 from sglang.srt.utils import is_generation_model
 
 DEFAULT_PROMPTS = [
-    "The capital of France is",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
+    "AI is a field of computer science focused on",
 ]
 
 NUM_TOP_LOGPROBS = 5
@@ -43,10 +45,11 @@ def get_dtype_str(torch_dtype):
 @dataclass
 class ModelOutput:
-    output_strs: str = None
-    top_input_logprobs: torch.Tensor = None
-    top_output_logprobs: torch.Tensor = None
-    embed_logits: torch.Tensor = None
+    output_strs: List[str] = None
+    output_ids: List[int] = None
+    top_input_logprobs: List[torch.Tensor] = None
+    top_output_logprobs: List[torch.Tensor] = None
+    embed_logits: List[torch.Tensor] = None
 
 
 class HFRunner:
@@ -117,7 +120,9 @@ class HFRunner:
                 output_ids = self.model.generate(
                     input_ids, do_sample=False, max_new_tokens=max_new_tokens
                 )
-                output_strs.append(self.tokenizer.decode(output_ids[0]))
+                output_strs.append(
+                    self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                )
 
                 logits = self.model.forward(input_ids).logits[0]
                 logprobs = F.log_softmax(
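For context on the decode change above: Hugging Face generate returns the prompt tokens followed by the newly generated tokens, so decoding output_ids[0] as before would prepend the prompt to every output string. Slicing from len(input_ids[0]) keeps only the completion, which is what the output-string comparison in the new test needs. A small sketch of this behavior (the gpt2 checkpoint is illustrative; any HF causal LM behaves the same way):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Today is a sunny day and I like", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=8)

full_text = tokenizer.decode(output_ids[0])                        # prompt + completion
completion = tokenizer.decode(output_ids[0][len(input_ids[0]):])   # completion only
print(completion)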
@@ -145,7 +150,7 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         self.in_queue.put((prompts, max_new_tokens))
         return self.out_queue.get()
@@ -184,7 +189,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=64,
+        max_new_tokens=8,
     ):
         if self.is_generation_model:
             # the return value contains logprobs from prefill
test/srt/models/test_generation_models.py

@@ -21,23 +21,25 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 MODELS = [
     ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1),
+    ("google/gemma-2-2b", 1),
 ]
 TORCH_DTYPES = [torch.float16]
 
 
-class TestCausalModels(unittest.TestCase):
-    def assert_close_prefill_logits(
+class TestGenerationModels(unittest.TestCase):
+    def assert_close_prefill_logits_and_output_strs(
         self,
         prompts,
         model_path,
         tp_size,
         torch_dtype,
+        max_new_tokens,
     ) -> None:
         with HFRunner(
             model_path, torch_dtype=torch_dtype, is_generation_model=True
         ) as hf_runner:
-            hf_outputs = hf_runner.forward(prompts)
+            hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         with SRTRunner(
             model_path,
@@ -45,7 +47,7 @@ class TestCausalModels(unittest.TestCase):
             torch_dtype=torch_dtype,
             is_generation_model=True,
         ) as srt_runner:
-            srt_outputs = srt_runner.forward(prompts)
+            srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
@@ -56,11 +58,18 @@ class TestCausalModels(unittest.TestCase):
                     abs(hf_logprobs - srt_logprobs) < tolerance
                 ), f"prefill logprobs not all close"
 
+        assert hf_outputs.output_strs == srt_outputs.output_strs
+
     def test_prefill_logits(self):
         for model, tp_size in MODELS:
             for torch_dtype in TORCH_DTYPES:
-                self.assert_close_prefill_logits(
-                    DEFAULT_PROMPTS, model, tp_size, torch_dtype
-                )
+                max_new_tokens = 8
+                self.assert_close_prefill_logits_and_output_strs(
+                    DEFAULT_PROMPTS, model, tp_size, torch_dtype, max_new_tokens
+                )
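The updated test now checks two things per prompt: numeric closeness of the prefill top-logprobs between the HF and SRT runners, and exact equality of the decoded output strings. A self-contained sketch of that comparison pattern with dummy values (the tensors and the tolerance below are illustrative, not taken from the test):

import torch

hf_logprobs = torch.tensor([[-0.10, -2.31, -4.05], [-0.52, -1.10, -3.97]])
srt_logprobs = torch.tensor([[-0.11, -2.30, -4.06], [-0.53, -1.09, -3.98]])
tolerance = 0.05  # illustrative threshold

# Elementwise closeness of prefill logprobs, as in the existing assertion.
assert torch.all(abs(hf_logprobs - srt_logprobs) < tolerance), "prefill logprobs not all close"

# New in this commit: decoded completions from HF and SRT must match exactly.
hf_output_strs = [" to go to the beach"]
srt_output_strs = [" to go to the beach"]
assert hf_output_strs == srt_output_strs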