Unverified Commit e73ed0f1 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Bugfix] Fix type annotations in CPU model runner (#4256)

parent 296cdf8a
...@@ -73,7 +73,8 @@ class CPUModelRunner: ...@@ -73,7 +73,8 @@ class CPUModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int]]: ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Optional[torch.Tensor]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -347,8 +348,8 @@ class CPUModelRunner: ...@@ -347,8 +348,8 @@ class CPUModelRunner:
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
SamplingMetadata]: Optional[torch.Tensor]]:
multi_modal_input = None multi_modal_input = None
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment