Commit 33e33aa7 authored by laibao's avatar laibao
Browse files

重构 InputBatch，移除 _expand_logitsprocs 方法并简化 logits 处理。

parent a0d556fe
......@@ -736,8 +736,6 @@ class InputBatch:
for repeat in repeat_list:
row_offsets.append(total_rows)
total_rows += int(repeat)
expanded_logitsprocs = self._expand_logitsprocs(
repeat_list, row_offsets, total_rows)
expanded_output_token_ids: list[list[int]] = []
expanded_bad_words_token_ids: dict[int, list[list[int]]] = {}
expanded_generators: dict[int, torch.Generator] = {}
......@@ -782,7 +780,7 @@ class InputBatch:
allowed_token_ids_mask=_expand_cpu_to_gpu(
allowed_token_ids_mask, dtype=torch.bool),
bad_words_token_ids=expanded_bad_words_token_ids,
logitsprocs=expanded_logitsprocs,
logitsprocs=self.logitsprocs,
)
@property
......@@ -827,110 +825,6 @@ class InputBatch:
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def _expand_logitsprocs(
    self, repeat_list: list[int], row_offsets: list[int], total_rows: int
) -> LogitsProcessorManager:
    """Build per-row copies of the batch's built-in logits processors.

    Request ``i`` is mapped onto the expanded rows
    ``row_offsets[i] .. row_offsets[i] + repeat_list[i] - 1``; requests
    whose repeat count is ``<= 0`` contribute no rows.

    Args:
        repeat_list: Number of expanded rows for each request index.
        row_offsets: First expanded row for each request index.
        total_rows: Total number of expanded rows.

    Returns:
        A ``LogitsProcessorManager`` in which ``MinPLogitsProcessor``
        entries of ``argmax_invariant`` and ``LogitBiasLogitsProcessor`` /
        ``MinTokensLogitsProcessor`` entries of ``non_argmax_invariant``
        are replaced by expanded instances. Every other processor object
        is passed through unchanged.
    """

    def _rows_of(req_idx: int) -> range:
        """Return the expanded row indices owned by ``req_idx``."""
        count = repeat_list[req_idx]
        if count <= 0:
            return range(0)
        first = row_offsets[req_idx]
        return range(first, first + count)

    def _expand_min_p(proc: MinPLogitsProcessor) -> MinPLogitsProcessor:
        """Repeat each request's min-p value once per expanded row."""
        clone = MinPLogitsProcessor(
            max_num_reqs=total_rows,
            pin_memory=self.pin_memory,
            device=self.device)
        if total_rows:
            per_request = torch.from_numpy(proc.min_p_cpu[:self.num_reqs])
            per_row = per_request.repeat_interleave(
                torch.tensor(repeat_list, dtype=torch.int64))
            # Stage through the processor's own CPU tensor before the
            # (possibly asynchronous) copy into its device buffer.
            staging = clone.min_p_cpu_tensor[:total_rows]
            staging.copy_(per_row)
            clone.min_p = clone.min_p_device[:total_rows]
            clone.min_p.copy_(staging, non_blocking=True)
            active = int((per_row != 0).sum().item())
        else:
            clone.min_p = clone.min_p_device[:0]
            active = 0
        clone.min_p.unsqueeze_(1)
        clone.min_p_count = active
        return clone

    def _expand_logit_bias(
            proc: LogitBiasLogitsProcessor) -> LogitBiasLogitsProcessor:
        """Replicate every request's token biases onto each of its rows."""
        clone = LogitBiasLogitsProcessor(pin_memory=self.pin_memory,
                                         device=self.device)
        # The biases dict is shared, not copied, so the clone keeps the
        # same truthiness as the source processor.
        clone.biases = proc.biases
        if total_rows == 0 or not proc.biases:
            return clone

        rows: list[int] = []
        toks: list[int] = []
        vals: list[float] = []
        for req_idx, token_bias in proc.biases.items():
            span = _rows_of(req_idx)
            if not span:
                continue
            tok_ids = list(token_bias.keys())
            tok_vals = list(token_bias.values())
            # Row-major layout: all tokens of the first row, then all
            # tokens of the next row, and so on.
            rows.extend(row for row in span for _ in tok_ids)
            toks.extend(tok_ids * len(span))
            vals.extend(tok_vals * len(span))

        if vals:
            clone.bias_tensor = clone._device_tensor(vals, torch.float32)
            clone.logits_slice = (
                clone._device_tensor(rows, torch.int32),
                clone._device_tensor(toks, torch.int32),
            )
        return clone

    def _expand_min_tokens(
            proc: MinTokensLogitsProcessor) -> MinTokensLogitsProcessor:
        """Replicate every request's stop-token ids onto each of its rows."""
        clone = MinTokensLogitsProcessor(pin_memory=self.pin_memory,
                                         device=self.device)
        clone.min_toks = proc.min_toks
        if total_rows == 0 or not proc.min_toks:
            return clone

        rows: list[int] = []
        toks: list[int] = []
        for req_idx, (_, _, stop_tok_ids) in proc.min_toks.items():
            span = _rows_of(req_idx)
            if not span:
                continue
            stop_ids = list(stop_tok_ids)
            rows.extend(row for row in span for _ in stop_ids)
            toks.extend(stop_ids * len(span))

        if toks:
            clone.logits_slice = (
                clone._device_tensor(rows, torch.int32),
                clone._device_tensor(toks, torch.int32),
            )
        return clone

    def _expand_non_argmax(proc):
        """Expand the processor kinds handled here; return others as-is."""
        if isinstance(proc, LogitBiasLogitsProcessor):
            return _expand_logit_bias(proc)
        if isinstance(proc, MinTokensLogitsProcessor):
            return _expand_min_tokens(proc)
        return proc

    argmax_procs: list = [
        _expand_min_p(proc) if isinstance(proc, MinPLogitsProcessor) else proc
        for proc in self.logitsprocs.argmax_invariant
    ]
    non_argmax_procs: list = [
        _expand_non_argmax(proc)
        for proc in self.logitsprocs.non_argmax_invariant
    ]
    return LogitsProcessorManager(
        argmax_invariant=argmax_procs,
        non_argmax_invariant=non_argmax_procs,
    )
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment