"vscode:/vscode.git/clone" did not exist on "1e636721bc765a4bfdda4512d9ad6f39c1a1a225"
Unverified Commit 980385f8 authored by Mathis Felardos's avatar Mathis Felardos Committed by GitHub
Browse files

[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and...


[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369)
Signed-off-by: default avatarMathis Felardos <mathis@mistral.ai>
parent ca7a2d5f
...@@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer end_layer = model_executable.model.end_layer
...@@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase): ...@@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase):
num_heads = int(model_config.num_key_value_heads / self.tp_size) num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads) head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
# query_lens contains new KV caches that are added to vLLM. # query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance # so we will send them to decode instance
...@@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase): ...@@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase):
for idx, slen in enumerate(seq_lens): for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx]) start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos] current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], [] keys, values = [], []
...@@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase):
# - input_tokens[num_prefill_tokens:] contains decode tokens. # - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False " logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens " "and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture") "should be equal to --max_seq_len_to_capture")
bypass_model_exec = False bypass_model_exec = False
assert start_pos == num_prefill_tokens assert start_pos == num_prefill_tokens
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment