Unverified Commit 71133a04 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Auto Sync] Update sampling_batch_info.py (20250909) (#10212)


Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: default avatarcctry <shiyang@x.ai>
parent 2cd94dd0
...@@ -67,28 +67,31 @@ class SamplingBatchInfo: ...@@ -67,28 +67,31 @@ class SamplingBatchInfo:
logit_bias: Optional[torch.Tensor] = None logit_bias: Optional[torch.Tensor] = None
@classmethod @classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
reqs = batch.reqs reqs = batch.reqs
device = batch.device device = batch.device
temperatures = ( temperatures = torch.tensor(
torch.tensor( [r.sampling_params.temperature for r in reqs],
[r.sampling_params.temperature for r in reqs], dtype=torch.float,
dtype=torch.float, device=device,
) ).view(-1, 1)
.view(-1, 1)
.to(device, non_blocking=True)
)
top_ps = torch.tensor( top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
).to(device, non_blocking=True) )
top_ks = torch.tensor( top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32 [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
).to(device, non_blocking=True) )
min_ps = torch.tensor( min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
).to(device, non_blocking=True) )
logit_bias = None logit_bias = None
if any(r.sampling_params.logit_bias is not None for r in reqs): if any(r.sampling_params.logit_bias is not None for r in reqs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment