Unverified Commit f6084f63 authored by Yang Zheng's avatar Yang Zheng Committed by GitHub
Browse files

[Speculative Decoding] Move indices to device before filtering output (#10850)


Co-authored-by: default avatarYang Zheng(SW)(Alex) <you@example.com>
parent 9323a315
...@@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): ...@@ -120,6 +120,9 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
indices_of_seq_with_bonus_tokens) indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output) model_outputs.append(model_output)
# move indices to device to avoid stream sync
indices_of_seq_with_bonus_tokens = torch.tensor(
indices_of_seq_with_bonus_tokens, device=self.device)
filtered_model_outputs = self._filter_model_output( filtered_model_outputs = self._filter_model_output(
model_outputs, indices_of_seq_with_bonus_tokens) model_outputs, indices_of_seq_with_bonus_tokens)
return filtered_model_outputs, True return filtered_model_outputs, True
...@@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): ...@@ -189,7 +192,7 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
@staticmethod @staticmethod
def _filter_model_output( def _filter_model_output(
expanded_batch_outputs: List[SamplerOutput], expanded_batch_outputs: List[SamplerOutput],
output_indices_to_retain: List[int]) -> List[SamplerOutput]: output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]:
""" """
Filters the model output to include only the specified sequence Filters the model output to include only the specified sequence
outputs. This method contracts the expanded batch output from the outputs. This method contracts the expanded batch output from the
...@@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase): ...@@ -199,8 +202,8 @@ class MultiStepWorker(ProposerWorkerBase, WorkerWrapperBase):
Args: Args:
expanded_batch_output (List[SamplerOutput]): The expanded output expanded_batch_output (List[SamplerOutput]): The expanded output
batch from the model. batch from the model.
output_indices_to_retain (List[int]): Indices of the model outputs output_indices_to_retain (torch.Tensor): Indices of the model
to retain. outputs to retain.
Returns: Returns:
List[SamplerOutput]: A list containing the filtered model List[SamplerOutput]: A list containing the filtered model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment