Commit 8776c63c authored by 王敏
Browse files

[fix] 修复开启并行解码后，当 speculative-disable-by-batch-size 设置得比测试的 batch 小时可能出现的数组越界问题

parent 8fc15e04
...@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True,
"""Update hidden states from target model invocation. Only used for """Update hidden states from target model invocation. Only used for
decode steps""" decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states) assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) # self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self.hidden_states = torch.cat([self.hidden_states, hidden_states])
# if self.second_last_token_hidden_states is not None:
# # Adding dummy hidden_states to this to maintain same shape
# self.second_last_token_hidden_states = torch.cat([
# self.second_last_token_hidden_states,
# torch.zeros_like(hidden_states)
# if second_last_token_hidden_states is None else
# second_last_token_hidden_states
# ])
seq_ids = get_all_seq_ids(seq_group_metadata_list)
diff_seq_ids = [item for item in self._seq_ids if item not in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in diff_seq_ids]
self._seq_ids = diff_seq_ids
self.hidden_states = self.hidden_states[index]
self.hidden_states = torch.cat([self.hidden_states, hidden_states]) self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None: if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape # Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = self.second_last_token_hidden_states[index]
self.second_last_token_hidden_states = torch.cat([ self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states, self.second_last_token_hidden_states,
torch.zeros_like(hidden_states) torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else if second_last_token_hidden_states is None else
second_last_token_hidden_states second_last_token_hidden_states
]) ])
self._seq_ids.extend(seq_ids)
def prune(self, def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
......
...@@ -691,15 +691,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -691,15 +691,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
torch.where(sampler_output.sampled_token_ids - torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]] VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer: if self.previous_hidden_states is None and len(
if self.previous_hidden_states is None and len( seq_group_meta_with_hidden):
seq_group_meta_with_hidden): self.previous_hidden_states = HiddenStates(
self.previous_hidden_states = HiddenStates( hidden_states, seq_group_meta_with_hidden)
hidden_states, seq_group_meta_with_hidden) elif self.previous_hidden_states and len(
elif self.previous_hidden_states and len( seq_group_meta_with_hidden):
seq_group_meta_with_hidden): self.previous_hidden_states.update(hidden_states,
self.previous_hidden_states.update(hidden_states, seq_group_meta_with_hidden)
seq_group_meta_with_hidden)
# Store logits from target model execution. # Store logits from target model execution.
if self.tree_decoding: if self.tree_decoding:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment