Unverified commit c98e84c2, authored by Ying Sheng and committed by GitHub

[Minor, Performance] Use torch.argmax for greedy sampling (#1589)

parent 9c064bf7
@@ -43,7 +43,10 @@ class Sampler(nn.Module):
             torch.isnan(probs), torch.full_like(probs, 1e-10), probs
         )
 
-        if global_server_args_dict["sampling_backend"] == "flashinfer":
+        if sampling_info.top_ks.max().item() <= 1:
+            # Use torch.argmax if all requests use greedy sampling
+            batch_next_token_ids = torch.argmax(probs, -1)
+        elif global_server_args_dict["sampling_backend"] == "flashinfer":
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
...
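The new branch above short-circuits sampling when every request in the batch is greedy: with top_k <= 1, drawing from the distribution reduces to picking the highest-probability token, so the RNG and top-k/top-p kernels can be skipped. A minimal standalone sketch of the idea, with made-up tensor shapes rather than the actual SGLang Sampler inputs:

import torch

# Hypothetical batch: a probability distribution over the vocabulary
# for each of 4 requests (shapes are illustrative only).
batch_size, vocab_size = 4, 32000
probs = torch.softmax(torch.randn(batch_size, vocab_size), dim=-1)
top_ks = torch.ones(batch_size, dtype=torch.int64)  # all requests greedy

if top_ks.max().item() <= 1:
    # Greedy decoding degenerates to one argmax reduction per row:
    # no uniform samples, no top-k/top-p rejection rounds.
    batch_next_token_ids = torch.argmax(probs, -1)
else:
    # Otherwise fall back to a real sampling kernel; multinomial stands in
    # here for the flashinfer/pytorch backends used in the actual code.
    batch_next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

print(batch_next_token_ids.shape)  # torch.Size([4])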
@@ -27,11 +27,11 @@ class TestBenchServing(unittest.TestCase):
             model=DEFAULT_MODEL_NAME_FOR_TEST,
             num_prompts=200,
             request_rate=float("inf"),
-            other_server_args=["--max-running-requests", "10"],
             dataset_name="sharegpt",
             random_input_len=None,
             random_output_len=None,
             disable_stream=True,
+            other_server_args=["--max-running-requests", "10"],
         )
         if is_in_ci():
...
@@ -1,6 +1,9 @@
+import json
 import unittest
 from types import SimpleNamespace
 
+import requests
+
 from sglang.srt.utils import kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
...
@@ -39,6 +42,32 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
         metrics = run_eval(args)
         assert metrics["score"] >= 0.65
 
+    def test_greedy(self):
+        response_single = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+        response_batch = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": ["The capital of France is"] * 10,
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 32,
+                },
+            },
+        ).json()
+
+        text = response_single["text"]
+        print(text)
+
+        for i in range(10):
+            assert response_batch[i]["text"] == text
 
 if __name__ == "__main__":
     unittest.main()
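The determinism that test_greedy asserts can also be checked by hand against a running server. A hedged sketch, assuming the server listens at SGLang's default port 30000 (the test class itself derives base_url from its launch helper):

import requests

base_url = "http://127.0.0.1:30000"  # assumed default; adjust to your server

payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0, "max_new_tokens": 32},
}

# With temperature 0 every request takes the argmax fast path, so a single
# prompt and each element of a batched prompt should return identical text.
single = requests.post(base_url + "/generate", json=payload).json()
batch = requests.post(
    base_url + "/generate",
    json={**payload, "text": [payload["text"]] * 10},
).json()

assert all(item["text"] == single["text"] for item in batch)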