Unverified Commit 71133a04, authored by Lianmin Zheng, committed by GitHub
Browse files

[Auto Sync] Update sampling_batch_info.py (20250909) (#10212)


Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: cctry <shiyang@x.ai>
parent 2cd94dd0
......@@ -67,28 +67,31 @@ class SamplingBatchInfo:
logit_bias: Optional[torch.Tensor] = None
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
def _get_global_server_args_dict(cls):
from sglang.srt.managers.schedule_batch import global_server_args_dict
return global_server_args_dict
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
global_server_args_dict = cls._get_global_server_args_dict()
reqs = batch.reqs
device = batch.device
temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
)
.view(-1, 1)
.to(device, non_blocking=True)
)
temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int32
).to(device, non_blocking=True)
[r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
)
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
logit_bias = None
if any(r.sampling_params.logit_bias is not None for r in reqs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment