Commit ba2ca2db authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.7.2-dev-wm' into 'v0.7.2-dev'

[fix]修复开并行解码后,speculative-disable-by-batch-size设的比测试的batch小的话可能出现的数组越界问题

See merge request dcutoolkit/deeplearing/vllm!99
parents 8fc15e04 8776c63c
...@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True, ...@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True,
"""Update hidden states from target model invocation. Only used for """Update hidden states from target model invocation. Only used for
decode steps""" decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states) assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list)) # self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self.hidden_states = torch.cat([self.hidden_states, hidden_states])
# if self.second_last_token_hidden_states is not None:
# # Adding dummy hidden_states to this to maintain same shape
# self.second_last_token_hidden_states = torch.cat([
# self.second_last_token_hidden_states,
# torch.zeros_like(hidden_states)
# if second_last_token_hidden_states is None else
# second_last_token_hidden_states
# ])
seq_ids = get_all_seq_ids(seq_group_metadata_list)
diff_seq_ids = [item for item in self._seq_ids if item not in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in diff_seq_ids]
self._seq_ids = diff_seq_ids
self.hidden_states = self.hidden_states[index]
self.hidden_states = torch.cat([self.hidden_states, hidden_states]) self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None: if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape # Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = self.second_last_token_hidden_states[index]
self.second_last_token_hidden_states = torch.cat([ self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states, self.second_last_token_hidden_states,
torch.zeros_like(hidden_states) torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else if second_last_token_hidden_states is None else
second_last_token_hidden_states second_last_token_hidden_states
]) ])
self._seq_ids.extend(seq_ids)
def prune(self, def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
......
...@@ -691,7 +691,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): ...@@ -691,7 +691,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
torch.where(sampler_output.sampled_token_ids - torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]] VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer:
if self.previous_hidden_states is None and len( if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden): seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates( self.previous_hidden_states = HiddenStates(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment