Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2972b7ed
"examples/vscode:/vscode.git/clone" did not exist on "777e602b3a3abc62a02f0348d96de7280fb2e5b3"
Unverified
Commit
2972b7ed
authored
Apr 21, 2025
by
ptarasiewiczNV
Committed by
GitHub
Apr 21, 2025
Browse files
feat: MLA disaggregation support to vLLM patch (#745)
parent
85d8d02d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
128 additions
and
52 deletions
+128
-52
container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch
container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch
+128
-52
No files found.
container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch
View file @
2972b7ed
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 54278f5f6..7eaf92feb 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -300,7 +300,8 @@
class MLACommonState(AttentionState, Generic[T]):
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
- self.enable_prefix_caching = cache_config.enable_prefix_caching
+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work
+ self.enable_prefix_caching = True # cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
self.context_chunk_workspace_size = min(
@@ -735,8 +736,8 @@
class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.block_size = input_builder.block_size
self.chunked_prefill_enabled = \
self.runner.scheduler_config.chunked_prefill_enabled
- self.enable_prefix_caching = \
- self.runner.cache_config.enable_prefix_caching
+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work
+ self.enable_prefix_caching = True # self.runner.cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
attn_state = self.input_builder.runner.attn_state
diff --git a/vllm/config.py b/vllm/config.py
diff --git a/vllm/config.py b/vllm/config.py
index 2912361ee..eea9cb65d 100644
index 2912361ee..eea9cb65d 100644
--- a/vllm/config.py
--- a/vllm/config.py
...
@@ -1014,10 +1039,10 @@ index 000000000..a2f9ce99e
...
@@ -1014,10 +1039,10 @@ index 000000000..a2f9ce99e
\
No newline at end of file
\
No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644
new file mode 100644
index 000000000..
136a0bd37
index 000000000..
bd4ac984e
--- /dev/null
--- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py
+++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,
394
@@
@@ -0,0 +1,
445
@@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-License-Identifier: Apache-2.0
+#
+#
...
@@ -1060,7 +1085,7 @@ index 000000000..136a0bd37
...
@@ -1060,7 +1085,7 @@ index 000000000..136a0bd37
+ dict=True):
+ dict=True):
+ engine_id: str
+ engine_id: str
+ agent_metadata: List[bytes]
+ agent_metadata: List[bytes]
+ kv_caches_base_addr: List[List[
Tuple[int,
int]]] # base address for each rank for each layer for keys and values
+ kv_caches_base_addr: List[List[
List[
int]]] # base address for each rank for each layer for keys and values
+ num_blocks: int
+ num_blocks: int
+
+
+
+
...
@@ -1096,6 +1121,7 @@ index 000000000..136a0bd37
...
@@ -1096,6 +1121,7 @@ index 000000000..136a0bd37
+
+
+
+
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+ self._is_mla = "deepseek" in vllm_config.model_config.architectures[0].lower()
+
+
+
+
+ @property
+ @property
...
@@ -1103,27 +1129,59 @@ index 000000000..136a0bd37
...
@@ -1103,27 +1129,59 @@ index 000000000..136a0bd37
+ return self.nixl_wrapper.name
+ return self.nixl_wrapper.name
+
+
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ logger.debug("--------------------------------")
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Registering kv caches for engine %s", self.engine_id)
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ logger.debug(f"Is deepseek: {self._is_mla}")
+ self.num_layers = len(kv_caches)
+ logger.debug(f"kv_cache shape: {kv_caches[0].shape}")
+ self.num_blocks = num_blocks
+ logger.debug("--------------------------------")
+ self.num_heads = num_heads
+
+ self.kv_caches = kv_caches
+ if self._is_mla:
+ kv_caches_base_addr = []
+ num_blocks, block_size, head_dim = kv_caches[0].shape
+ caches_data = []
+ self.block_len = head_dim * block_size * kv_caches[0].element_size()
+ for key_cache, value_cache in kv_caches:
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ base_addr = key_cache.data_ptr()
+ self.num_layers = len(kv_caches)
+ region_len = 2 * num_blocks * self.block_len
+ self.num_blocks = num_blocks
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ self.num_heads = 1
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
+ self.kv_caches = kv_caches
+
+ self.num_cache_entries = 1
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+
+ kv_caches_base_addr = []
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ caches_data = []
+ logger.debug("Registering descs: %s", caches_data)
+ for kv_cache in kv_caches:
+ self.nixl_wrapper.register_memory(descs)
+ base_addr = kv_cache.data_ptr()
+ self._registered_descs.append(descs)
+ region_len = self.num_cache_entries * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append([base_addr,])
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+ else:
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+ self.num_cache_entries = 2
+ kv_caches_base_addr = []
+ caches_data = []
+ for key_cache, value_cache in kv_caches:
+ base_addr = key_cache.data_ptr()
+ region_len = self.num_cache_entries * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append([key_cache.data_ptr(), value_cache.data_ptr()])
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+
+
+ def get_agent_metadata(self):
+ def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata()
+ return self.nixl_wrapper.get_agent_metadata()
...
@@ -1138,7 +1196,7 @@ index 000000000..136a0bd37
...
@@ -1138,7 +1196,7 @@ index 000000000..136a0bd37
+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ self.nixl_wrapper.
d
ele
te_xfer_sid
e(dst_xfer_side_handle)
+ self.nixl_wrapper.
r
ele
ase_dlist_handl
e(dst_xfer_side_handle)
+
+
+ def _get_ranges(self, block_ids):
+ def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous
+ # This function should return a list of ranges of block ids that are contiguous
...
@@ -1167,20 +1225,23 @@ index 000000000..136a0bd37
...
@@ -1167,20 +1225,23 @@ index 000000000..136a0bd37
+ if i is not None:
+ if i is not None:
+ num_blocks = self.num_blocks
+ num_blocks = self.num_blocks
+ for layer_id in layer_ids:
+ for layer_id in layer_ids:
+ for
is_value in [0, 1]
:
+ for
entry_index in range(self.num_cache_entries)
:
+ staging_range_idx = 0
+ staging_range_idx = 0
+ for block_id in block_ids:
+ for block_id in block_ids:
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ if staging_ranges is not None:
+ staging_range_idx += 1
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ start_offset = staging_ranges[staging_range_idx][0]
+ staging_range_idx += 1
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ start_offset = staging_ranges[staging_range_idx][0]
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks * tp_multiplier + entry_index * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks + entry_index * num_blocks + block_id)
+ else:
+ else:
+ num_blocks = self.dst_num_blocks[engine_id]
+ num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids:
+ for layer_id in layer_ids:
+ for
is_value in [0, 1]
:
+ for
entry_index in range(self.num_cache_entries)
:
+ for block_id in block_ids:
+ for block_id in block_ids:
+ descs_ids.append(layer_id *
2
* num_blocks +
is_value
* num_blocks + block_id)
+ descs_ids.append(layer_id *
self.num_cache_entries
* num_blocks +
entry_index
* num_blocks + block_id)
+ return descs_ids
+ return descs_ids
+
+
+ def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
+ def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
...
@@ -1243,10 +1304,15 @@ index 000000000..136a0bd37
...
@@ -1243,10 +1304,15 @@ index 000000000..136a0bd37
+
+
+ start_time = time.perf_counter()
+ start_time = time.perf_counter()
+
+
+ local_ranges = self._get_ranges(local_block_ids)
+ if self._is_mla:
+ staging_ranges = self._get_ranges(staging_block_ids)
+ # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all
+ staging_rearranging_ranges = None
+ staging_block_ids = local_block_ids
+ else:
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+
+
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+
+
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
...
@@ -1282,11 +1348,12 @@ index 000000000..136a0bd37
...
@@ -1282,11 +1348,12 @@ index 000000000..136a0bd37
+
+
+ rearrange_start_time = time.perf_counter()
+ rearrange_start_time = time.perf_counter()
+
+
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ if not self._is_mla:
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ for kv_cache in self.kv_caches:
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for cache in kv_cache:
+ for kv_cache in self.kv_caches:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
+ for cache in kv_cache:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
+
+
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
+ logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
...
@@ -1311,16 +1378,21 @@ index 000000000..136a0bd37
...
@@ -1311,16 +1378,21 @@ index 000000000..136a0bd37
+
+
+ start_time = time.perf_counter()
+ start_time = time.perf_counter()
+
+
+ local_ranges = self._get_ranges(local_block_ids)
+ if self._is_mla:
+ staging_ranges = self._get_ranges(staging_block_ids)
+ # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all
+ staging_rearranging_ranges = None
+ staging_block_ids = local_block_ids
+ else:
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+
+
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+
+
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for kv_cache in self.kv_caches:
+
for kv_cache in self.kv_caches:
+ for cache in kv_cache:
+
for cache in kv_cache:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
+
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
+
+
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+
+
...
@@ -1330,6 +1402,7 @@ index 000000000..136a0bd37
...
@@ -1330,6 +1402,7 @@ index 000000000..136a0bd37
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+
+
+ logger.debug("Creating xfer handles")
+ for i in range(tp_multiplier):
+ for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
...
@@ -1364,7 +1437,10 @@ index 000000000..136a0bd37
...
@@ -1364,7 +1437,10 @@ index 000000000..136a0bd37
+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
+
+
+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
+ dst_block_len = self.block_len // tp_multiplier
+ if self._is_mla:
+ dst_block_len = self.block_len
+ else:
+ dst_block_len = self.block_len // tp_multiplier
+ if tp_multiplier not in self.src_xfer_side_handles:
+ if tp_multiplier not in self.src_xfer_side_handles:
+ # create descs and xfer side handles
+ # create descs and xfer side handles
+ blocks_data = []
+ blocks_data = []
...
@@ -1372,7 +1448,7 @@ index 000000000..136a0bd37
...
@@ -1372,7 +1448,7 @@ index 000000000..136a0bd37
+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
+ for block_id in range(self.num_blocks):
+ for block_id in range(self.num_blocks):
+ block_offset = block_id * self.block_len
+ block_offset = block_id * self.block_len
+ for i in range(tp_multiplier):
+ for i in range(
1 if self._is_mla else
tp_multiplier):
+ tp_multiplier_offset = i * dst_block_len
+ tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment