"vllm/vscode:/vscode.git/clone" did not exist on "e2df354493d974c836467f04a4a6e489d20b3d1d"
Commit ba2ca2db authored by zhuwenwen
Browse files

Merge branch 'v0.7.2-dev-wm' into 'v0.7.2-dev'

[fix]修复开并行解码后,speculative-disable-by-batch-size设的比测试的batch小的话可能出现的数组越界问题

See merge request dcutoolkit/deeplearing/vllm!99
parents 8fc15e04 8776c63c
......@@ -1300,17 +1300,36 @@ class HiddenStates(msgspec.Struct, array_like=True,
"""Update hidden states from target model invocation. Only used for
decode steps"""
assert len(seq_group_metadata_list) == len(hidden_states)
self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
# self.hidden_states = torch.cat([self.hidden_states, hidden_states])
# if self.second_last_token_hidden_states is not None:
# # Adding dummy hidden_states to this to maintain same shape
# self.second_last_token_hidden_states = torch.cat([
# self.second_last_token_hidden_states,
# torch.zeros_like(hidden_states)
# if second_last_token_hidden_states is None else
# second_last_token_hidden_states
# ])
seq_ids = get_all_seq_ids(seq_group_metadata_list)
diff_seq_ids = [item for item in self._seq_ids if item not in seq_ids]
index = [self._seq_ids.index(seq_id) for seq_id in diff_seq_ids]
self._seq_ids = diff_seq_ids
self.hidden_states = self.hidden_states[index]
self.hidden_states = torch.cat([self.hidden_states, hidden_states])
if self.second_last_token_hidden_states is not None:
# Adding dummy hidden_states to this to maintain same shape
self.second_last_token_hidden_states = self.second_last_token_hidden_states[index]
self.second_last_token_hidden_states = torch.cat([
self.second_last_token_hidden_states,
torch.zeros_like(hidden_states)
if second_last_token_hidden_states is None else
second_last_token_hidden_states
])
self._seq_ids.extend(seq_ids)
def prune(self,
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
......
......@@ -691,7 +691,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
torch.where(sampler_output.sampled_token_ids -
VLLM_INVALID_TOKEN_ID)[0]]
if not skip_proposer:
if self.previous_hidden_states is None and len(
seq_group_meta_with_hidden):
self.previous_hidden_states = HiddenStates(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment