Unverified Commit 980385f8 authored by Mathis Felardos, committed by GitHub
Browse files

[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and...


[Bugfix][Disaggregated] Add a check in send_kv_caches_and_hidden_states and fix the reshape of the KVCache (#14369)
Signed-off-by: Mathis Felardos <mathis@mistral.ai>
parent ca7a2d5f
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
""" """
Simple KV Cache Connector for Distributed Machine Learning Inference Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe. MooncakePipe.
...@@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -159,6 +159,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens
start_layer = model_executable.model.start_layer start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer end_layer = model_executable.model.end_layer
...@@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase): ...@@ -166,7 +167,8 @@ class SimpleConnector(KVConnectorBase):
num_heads = int(model_config.num_key_value_heads / self.tp_size) num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads) head_size = getattr(model_config, "head_dim",
int(hidden_size // num_attention_heads))
# query_lens contains new KV caches that are added to vLLM. # query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance # so we will send them to decode instance
...@@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase): ...@@ -174,6 +176,15 @@ class SimpleConnector(KVConnectorBase):
for idx, slen in enumerate(seq_lens): for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx]) start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen end_pos = start_pos + slen
if start_pos >= num_prefill_tokens:
# vllm/worker/model_runner.py::_prepare_model_input_tensors:
# - input_tokens[:num_prefill_tokens] contains prefill tokens.
# - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You have some decode requests while using "
"SimpleConnector. Their KVCache won't be sent.")
break
current_tokens = input_tokens_tensor[start_pos:end_pos] current_tokens = input_tokens_tensor[start_pos:end_pos]
keys, values = [], [] keys, values = [], []
...@@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -236,7 +247,7 @@ class SimpleConnector(KVConnectorBase):
# - input_tokens[num_prefill_tokens:] contains decode tokens. # - input_tokens[num_prefill_tokens:] contains decode tokens.
logger.warning("You should set --enable_chunked_prefill=False " logger.warning("You should set --enable_chunked_prefill=False "
"and --max_num_batched_tokens " "and --max_num_batched_tokens "
"should be equal to max_seq_len_to_capture") "should be equal to --max_seq_len_to_capture")
bypass_model_exec = False bypass_model_exec = False
assert start_pos == num_prefill_tokens assert start_pos == num_prefill_tokens
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment