Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9a7b506c
Commit
9a7b506c
authored
Feb 17, 2025
by
ptarasiewiczNV
Committed by
GitHub
Feb 17, 2025
Browse files
feat: disaggregation support for deepseek arch models (#196)
parent
fd0bcfa2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
102 additions
and
23 deletions
+102
-23
container/deps/vllm/vllm_v0.7.2.patch
container/deps/vllm/vllm_v0.7.2.patch
+102
-23
No files found.
container/deps/vllm/vllm_v0.7.2.patch
View file @
9a7b506c
...
...
@@ -74,7 +74,7 @@ index fe480533..b768e03c 100644
# Register various connectors here.
diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
index 2033e976..71cd0567 100644
index 2033e976..e33919c1 100644
--- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py
@@ -8,13 +8,15 @@
MooncakePipe.
...
...
@@ -94,7 +94,7 @@ index 2033e976..71cd0567 100644
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@@ -33,10 +35,10 @@
class SimpleConnector(KVConnectorBase):
@@ -33,6 +35,7 @@
class SimpleConnector(KVConnectorBase):
rank: int,
local_rank: int,
config: VllmConfig,
...
...
@@ -102,11 +102,7 @@ index 2033e976..71cd0567 100644
):
self.config = config.kv_transfer_config
- self.tp_size = config.parallel_config.tensor_parallel_size
if self.config.kv_connector == "PyNcclConnector":
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
@@ -71,20 +73,31 @@
class SimpleConnector(KVConnectorBase):
@@ -71,20 +74,31 @@
class SimpleConnector(KVConnectorBase):
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
...
...
@@ -139,7 +135,7 @@ index 2033e976..71cd0567 100644
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -108,11 +121,13 @@
class SimpleConnector(KVConnectorBase):
@@ -108,11 +122,13 @@
class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producder
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
...
...
@@ -153,7 +149,7 @@ index 2033e976..71cd0567 100644
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@@ -131,21 +146,25 @@
class SimpleConnector(KVConnectorBase):
@@ -131,21 +147,25 @@
class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
...
...
@@ -183,15 +179,32 @@ index 2033e976..71cd0567 100644
def send_kv_caches_and_hidden_states(
self,
@@ -161,6 +180,7 @@
class SimpleConnector(KVConnectorBase):
@@ -161,12 +181,20 @@
class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
+ request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
@@ -175,27 +195,36 @@
class SimpleConnector(KVConnectorBase):
- num_heads = int(model_config.num_key_value_heads / self.tp_size)
- hidden_size = model_config.hidden_size
- num_attention_heads = model_config.num_attention_heads
- head_size = int(hidden_size / num_attention_heads)
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+ if not is_deepseek:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(hidden_size / num_attention_heads)
+ else:
+ num_heads = int(model_config.num_key_value_heads / self.tp_size)
+ hidden_size = model_config.hidden_size
+ num_attention_heads = model_config.num_attention_heads
+ head_size = int(4.5 * hidden_size / num_attention_heads)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
@@ -175,27 +203,40 @@
class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
...
...
@@ -200,17 +213,14 @@ index 2033e976..71cd0567 100644
+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
+
+ for target_rank in range(self.config.tensor_parallel_multiplier):
+
+ keys, values = [], []
- keys, values = [], []
+ for layer_id in range(start_layer, end_layer):
+ kv_cache = kv_caches[layer_id - start_layer]
+ keys, values = [], []
- for layer_id in range(start_layer, end_layer):
- kv_cache = kv_caches[layer_id - start_layer]
+
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+
for layer_id in range(start_layer, end_layer):
+
kv_cache = kv_caches[layer_id - start_layer]
- key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
- value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
...
...
@@ -223,8 +233,15 @@ index 2033e976..71cd0567 100644
- keys.append(key_cache[current_slot_mapping].unsqueeze(0))
- values.append(value_cache[current_slot_mapping].unsqueeze(0))
+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ if not is_deepseek:
+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
+ else:
+ key_cache = kv_cache
+ keys.append(key_cache[current_slot_mapping].unsqueeze(0))
+ values.append(torch.empty(0))
- keys = torch.cat(keys, dim=0)
- values = torch.cat(values, dim=0)
...
...
@@ -242,7 +259,7 @@ index 2033e976..71cd0567 100644
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
@@ -215,6 +244,7 @@
class SimpleConnector(KVConnectorBase):
@@ -215,6 +256,7 @@
class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
...
...
@@ -250,7 +267,17 @@ index 2033e976..71cd0567 100644
hidden_or_intermediate_states_for_one_req = []
@@ -229,13 +259,15 @@
class SimpleConnector(KVConnectorBase):
@@ -222,6 +264,9 @@
class SimpleConnector(KVConnectorBase):
num_computed_tokens_list = []
start_pos_list = []
+ model_config = model_executable.model.config
+ is_deepseek = "deepseek" in model_config.architectures[0].lower()
+
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
@@ -229,13 +274,15 @@
class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
...
...
@@ -267,7 +294,46 @@ index 2033e976..71cd0567 100644
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
@@ -312,3 +344,77 @@
class SimpleConnector(KVConnectorBase):
@@ -267,19 +314,25 @@
class SimpleConnector(KVConnectorBase):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
- key_cache, value_cache = kv_cache[0], kv_cache[1]
- ops.reshape_and_cache_flash(
- keys[i - model_executable.model.start_layer].to(
- key_cache.device),
- values[i - model_executable.model.start_layer].to(
- value_cache.device),
- key_cache,
- value_cache,
- slot_mapping[start_pos:end_pos],
- layer.self_attn.attn.kv_cache_dtype,
- layer.self_attn.attn._k_scale,
- layer.self_attn.attn._v_scale,
- )
+ if not is_deepseek:
+ key_cache, value_cache = kv_cache[0], kv_cache[1]
+ ops.reshape_and_cache_flash(
+ keys[i - model_executable.model.start_layer].to(
+ key_cache.device),
+ values[i - model_executable.model.start_layer].to(
+ value_cache.device),
+ key_cache,
+ value_cache,
+ slot_mapping[start_pos:end_pos],
+ layer.self_attn.attn.kv_cache_dtype,
+ layer.self_attn.attn._k_scale,
+ layer.self_attn.attn._v_scale,
+ )
+ else:
+ key_cache = kv_cache
+ copy_from =keys[i - model_executable.model.start_layer].to(
+ key_cache.device)
+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
hidden_or_intermediate_states_for_one_req.append(hidden)
@@ -312,3 +365,77 @@
class SimpleConnector(KVConnectorBase):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
...
...
@@ -823,3 +889,16 @@ index 321902d1..b8937ef8 100644
def ensure_model_parallel_initialized(
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 773f5abe..3eefd266 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -585,6 +585,8 @@
class DeepseekV2Model(nn.Module):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
+ self.config = config
+
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment