Unverified Commit df450aa5 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[Bugfix] Fix num_heads value for simple connector when tp enabled (#12074)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent bbe5f9de
...@@ -35,6 +35,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -35,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
): ):
self.config = config.kv_transfer_config self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector": if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
...@@ -161,7 +162,7 @@ class SimpleConnector(KVConnectorBase): ...@@ -161,7 +162,7 @@ class SimpleConnector(KVConnectorBase):
end_layer = model_executable.model.end_layer end_layer = model_executable.model.end_layer
model_config = model_executable.model.config model_config = model_executable.model.config
num_heads = model_config.num_key_value_heads num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads) head_size = int(hidden_size / num_attention_heads)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment