Commit c98e84c2 (unverified), authored Oct 06, 2024 by Ying Sheng; committed by GitHub on Oct 06, 2024. Parent: 9c064bf7.

[Minor, Performance] Use torch.argmax for greedy sampling (#1589)
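Why this shortcut is safe: scaling logits by any positive temperature never reorders them, so the zero-temperature (greedy) limit of sampling picks the same token as a plain argmax. A quick standalone sanity check (illustrative tensors, not part of the commit):

```python
import torch

logits = torch.randn(4, 8)  # hypothetical batch of 4 requests, vocab of 8

# Dividing logits by a temperature t > 0 never reorders them, so the
# t -> 0 "greedy" limit of sampling selects the same token as argmax.
greedy = torch.argmax(logits, -1)
near_zero_temp = torch.argmax(torch.softmax(logits / 1e-6, dim=-1), -1)
assert torch.equal(greedy, near_zero_temp)
```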
Showing 3 changed files, with 34 additions and 2 deletions (+34, -2):

- python/sglang/srt/layers/sampler.py (+4, -1)
- test/srt/test_bench_serving.py (+1, -1)
- test/srt/test_pytorch_sampling_backend.py (+29, -0)
python/sglang/srt/layers/sampler.py:
@@ -43,7 +43,10 @@ class Sampler(nn.Module):
             torch.isnan(probs), torch.full_like(probs, 1e-10), probs
         )
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
+        if sampling_info.top_ks.max().item() <= 1:
+            # Use torch.argmax if all requests use greedy sampling
+            batch_next_token_ids = torch.argmax(probs, -1)
+        elif global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
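As a minimal standalone sketch of the branch this hunk adds (the tensors are hypothetical stand-ins for sglang's batch state, and torch.multinomial stands in for the flashinfer sampling kernel):

```python
import torch

# Hypothetical batch: 4 requests over a vocabulary of 8 tokens.
probs = torch.softmax(torch.randn(4, 8), dim=-1)
top_ks = torch.tensor([1, 1, 1, 1])  # top_k == 1 means greedy

if top_ks.max().item() <= 1:
    # All requests are greedy: one reduction over the vocab dimension,
    # no RNG and no per-round rejection sampling.
    batch_next_token_ids = torch.argmax(probs, -1)
else:
    # Mixed batch: fall back to real sampling (flashinfer in sglang;
    # torch.multinomial here as a stand-in).
    batch_next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

print(batch_next_token_ids.shape)  # torch.Size([4])
```

Note that the fast path only fires when every request in the batch is greedy; a single non-greedy request sends the whole batch down the sampling backend.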
test/srt/test_bench_serving.py:
@@ -27,11 +27,11 @@ class TestBenchServing(unittest.TestCase):
             model=DEFAULT_MODEL_NAME_FOR_TEST,
             num_prompts=200,
             request_rate=float("inf"),
-            other_server_args=["--max-running-requests", "10"],
             dataset_name="sharegpt",
             random_input_len=None,
             random_output_len=None,
             disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
         )
 
         if is_in_ci():
test/srt/test_pytorch_sampling_backend.py:
+import json
 import unittest
 from types import SimpleNamespace
 
+import requests
+
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
@@ -39,6 +42,32 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
         metrics = run_eval(args)
         assert metrics["score"] >= 0.65
 
+    def test_greedy(self):
+        response_single = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+        response_batch = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": ["The capital of France is"] * 10,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+
+        text = response_single["text"]
+        print(text)
+        for i in range(10):
+            assert response_batch[i]["text"] == text
 
 
 if __name__ == "__main__":
     unittest.main()
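The new test pins down the user-visible contract of the change: with temperature 0, single-prompt and batched generations must be token-identical. The same check can be run as a standalone script against a live server; the URL below is an assumption for illustration (the real test takes base_url from its suite setup):

```python
import requests

base_url = "http://127.0.0.1:30000"  # assumed local sglang server address

prompt = "The capital of France is"
params = {"temperature": 0, "max_new_tokens": 32}

single = requests.post(
    base_url + "/generate",
    json={"text": prompt, "sampling_params": params},
).json()
batch = requests.post(
    base_url + "/generate",
    json={"text": [prompt] * 10, "sampling_params": params},
).json()

# With the argmax fast path, greedy decoding is deterministic, so each
# batched completion must match the single-request one token for token.
assert all(item["text"] == single["text"] for item in batch)
```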