Commit 8776c63c authored by 王敏's avatar 王敏
Browse files

[fix]修复开并行解码后,speculative-disable-by-batch-size设的比测试的batch小的话可能出现的数组越界问题

parent 8fc15e04
......@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True,
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self.hidden_states = torch.cat([self.hidden_states, hidden_states])
# if self.second_last_token_hidden_states is not None:
# # Adding dummy hidden_states to this to maintain same shape
# self.second_last_token_hidden_states = torch.cat([
# self.second_last_token_hidden_states,
# torch.zeros_like(hidden_states)
# if second_last_token_hidden_states is None else
# second_last_token_hidden_states
# ])
seq_ids = get_all_seq_ids(seq_group_metadata_list)
diff_seq_ids = [item for item in self._seq_ids if item not in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in diff_seq_ids]
self._seq_ids = diff_seq_ids
self.hidden_states = self.hidden_states[index]
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = self.second_last_token_hidden_states[index]
self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states,
torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else
second_last_token_hidden_states
])
self._seq_ids.extend(seq_ids)
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
......
......@@ -691,15 +691,14 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer:
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
hidden_states, seq_group_meta_with_hidden)
elif self.previous_hidden_states and len(
seq_group_meta_with_hidden):
self.previous_hidden_states.update(hidden_states,
seq_group_meta_with_hidden)
# Store logits from target model execution.
if self.tree_decoding:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment