Commit 371e5c76 authored by laibao's avatar laibao
Browse files

fix: 修复 expanded sampling metadata 对 numpy/array-like 输入不兼容导致崩溃

  - repeat_counts/CPU 元数据为 numpy/array-like 时会在 repeat_interleave/.to() 崩溃
  - 统一转换为 CPU torch.Tensor 后再扩展并拷到 GPU
parent 952f0347
...@@ -735,21 +735,34 @@ class InputBatch: ...@@ -735,21 +735,34 @@ class InputBatch:
self, repeat_counts: torch.Tensor self, repeat_counts: torch.Tensor
) -> SamplingMetadata: ) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
repeat_counts_cpu = repeat_counts # `repeat_counts` is expected to be a CPU torch tensor, but some
# call sites may pass a NumPy array (or other array-likes). Normalize
# to a CPU tensor to keep downstream ops (e.g. repeat_interleave)
# consistent and avoid hard crashes.
if isinstance(repeat_counts, torch.Tensor):
repeat_counts_cpu = repeat_counts.to(device="cpu")
else:
repeat_counts_cpu = torch.as_tensor(repeat_counts, device="cpu")
all_greedy = self.all_greedy all_greedy = self.all_greedy
all_random = self.all_random all_random = self.all_random
# For reject-sampling optimization, force greedy sampling to keep # For reject-sampling optimization, force greedy sampling to keep
# rejection sampler assumptions (per-request shapes) intact. # rejection sampler assumptions (per-request shapes) intact.
def _expand_cpu_to_gpu( def _expand_cpu_to_gpu(
t: Optional[torch.Tensor], t: Optional[object],
*, *,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
if t is None: if t is None:
return None return None
# `t` should be a CPU torch tensor, but can be a NumPy array view
# (e.g. created via `tensor.numpy()`). Convert if needed.
if isinstance(t, torch.Tensor):
base = t[:num_reqs] base = t[:num_reqs]
if repeat_counts_cpu is not None: elif isinstance(t, np.ndarray):
base = torch.from_numpy(t[:num_reqs])
else:
base = torch.as_tensor(t, device="cpu")[:num_reqs]
base = base.repeat_interleave(repeat_counts_cpu, dim=0) base = base.repeat_interleave(repeat_counts_cpu, dim=0)
return base.to(device=self.device, return base.to(device=self.device,
dtype=dtype if dtype is not None else None, dtype=dtype if dtype is not None else None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment