"examples/vscode:/vscode.git/clone" did not exist on "777e602b3a3abc62a02f0348d96de7280fb2e5b3"
Unverified Commit 2972b7ed authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

feat: MLA disaggregation support to vLLM patch (#745)

parent 85d8d02d
diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 54278f5f6..7eaf92feb 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -300,7 +300,8 @@ class MLACommonState(AttentionState, Generic[T]):
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
- self.enable_prefix_caching = cache_config.enable_prefix_caching
+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work
+ self.enable_prefix_caching = True # cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
self.context_chunk_workspace_size = min(
@@ -735,8 +736,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self.block_size = input_builder.block_size
self.chunked_prefill_enabled = \
self.runner.scheduler_config.chunked_prefill_enabled
- self.enable_prefix_caching = \
- self.runner.cache_config.enable_prefix_caching
+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work
+ self.enable_prefix_caching = True # self.runner.cache_config.enable_prefix_caching
if self.chunked_prefill_enabled or self.enable_prefix_caching:
attn_state = self.input_builder.runner.attn_state
diff --git a/vllm/config.py b/vllm/config.py diff --git a/vllm/config.py b/vllm/config.py
index 2912361ee..eea9cb65d 100644 index 2912361ee..eea9cb65d 100644
--- a/vllm/config.py --- a/vllm/config.py
...@@ -1014,10 +1039,10 @@ index 000000000..a2f9ce99e ...@@ -1014,10 +1039,10 @@ index 000000000..a2f9ce99e
\ No newline at end of file \ No newline at end of file
diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py
new file mode 100644 new file mode 100644
index 000000000..136a0bd37 index 000000000..bd4ac984e
--- /dev/null --- /dev/null
+++ b/vllm/distributed/device_communicators/nixl.py +++ b/vllm/distributed/device_communicators/nixl.py
@@ -0,0 +1,394 @@ @@ -0,0 +1,445 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 +# SPDX-License-Identifier: Apache-2.0
+# +#
...@@ -1060,7 +1085,7 @@ index 000000000..136a0bd37 ...@@ -1060,7 +1085,7 @@ index 000000000..136a0bd37
+ dict=True): + dict=True):
+ engine_id: str + engine_id: str
+ agent_metadata: List[bytes] + agent_metadata: List[bytes]
+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values + kv_caches_base_addr: List[List[List[int]]] # base address for each rank for each layer for keys and values
+ num_blocks: int + num_blocks: int
+ +
+ +
...@@ -1096,6 +1121,7 @@ index 000000000..136a0bd37 ...@@ -1096,6 +1121,7 @@ index 000000000..136a0bd37
+ +
+ +
+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size + self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
+ self._is_mla = "deepseek" in vllm_config.model_config.architectures[0].lower()
+ +
+ +
+ @property + @property
...@@ -1103,27 +1129,59 @@ index 000000000..136a0bd37 ...@@ -1103,27 +1129,59 @@ index 000000000..136a0bd37
+ return self.nixl_wrapper.name + return self.nixl_wrapper.name
+ +
+ def register_kv_caches(self, kv_caches: List[torch.Tensor]): + def register_kv_caches(self, kv_caches: List[torch.Tensor]):
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape + logger.debug("--------------------------------")
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() + logger.debug("Registering kv caches for engine %s", self.engine_id)
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) + logger.debug(f"Is deepseek: {self._is_mla}")
+ self.num_layers = len(kv_caches) + logger.debug(f"kv_cache shape: {kv_caches[0].shape}")
+ self.num_blocks = num_blocks + logger.debug("--------------------------------")
+ self.num_heads = num_heads +
+ self.kv_caches = kv_caches + if self._is_mla:
+ kv_caches_base_addr = [] + num_blocks, block_size, head_dim = kv_caches[0].shape
+ caches_data = [] + self.block_len = head_dim * block_size * kv_caches[0].element_size()
+ for key_cache, value_cache in kv_caches: + logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ base_addr = key_cache.data_ptr() + self.num_layers = len(kv_caches)
+ region_len = 2 * num_blocks * self.block_len + self.num_blocks = num_blocks
+ caches_data.append((base_addr, region_len, self.rank, "")) + self.num_heads = 1
+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) + self.kv_caches = kv_caches
+ + self.num_cache_entries = 1
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr +
+ + kv_caches_base_addr = []
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") + caches_data = []
+ logger.debug("Registering descs: %s", caches_data) + for kv_cache in kv_caches:
+ self.nixl_wrapper.register_memory(descs) + base_addr = kv_cache.data_ptr()
+ self._registered_descs.append(descs) + region_len = self.num_cache_entries * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append([base_addr,])
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+ else:
+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
+ self.num_layers = len(kv_caches)
+ self.num_blocks = num_blocks
+ self.num_heads = num_heads
+ self.kv_caches = kv_caches
+ self.num_cache_entries = 2
+ kv_caches_base_addr = []
+ caches_data = []
+ for key_cache, value_cache in kv_caches:
+ base_addr = key_cache.data_ptr()
+ region_len = self.num_cache_entries * num_blocks * self.block_len
+ caches_data.append((base_addr, region_len, self.rank, ""))
+ kv_caches_base_addr.append([key_cache.data_ptr(), value_cache.data_ptr()])
+
+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
+
+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
+ logger.debug("Registering descs: %s", caches_data)
+ self.nixl_wrapper.register_memory(descs)
+ self._registered_descs.append(descs)
+ +
+ def get_agent_metadata(self): + def get_agent_metadata(self):
+ return self.nixl_wrapper.get_agent_metadata() + return self.nixl_wrapper.get_agent_metadata()
...@@ -1138,7 +1196,7 @@ index 000000000..136a0bd37 ...@@ -1138,7 +1196,7 @@ index 000000000..136a0bd37
+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) + self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): + for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
+ for dst_xfer_side_handle in dst_xfer_side_handles.values(): + for dst_xfer_side_handle in dst_xfer_side_handles.values():
+ self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) + self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
+ +
+ def _get_ranges(self, block_ids): + def _get_ranges(self, block_ids):
+ # This function should return a list of ranges of block ids that are contiguous + # This function should return a list of ranges of block ids that are contiguous
...@@ -1167,20 +1225,23 @@ index 000000000..136a0bd37 ...@@ -1167,20 +1225,23 @@ index 000000000..136a0bd37
+ if i is not None: + if i is not None:
+ num_blocks = self.num_blocks + num_blocks = self.num_blocks
+ for layer_id in layer_ids: + for layer_id in layer_ids:
+ for is_value in [0, 1]: + for entry_index in range(self.num_cache_entries):
+ staging_range_idx = 0 + staging_range_idx = 0
+ for block_id in block_ids: + for block_id in block_ids:
+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: + if staging_ranges is not None:
+ staging_range_idx += 1 + if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
+ start_offset = staging_ranges[staging_range_idx][0] + staging_range_idx += 1
+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) + start_offset = staging_ranges[staging_range_idx][0]
+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) + i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks * tp_multiplier + entry_index * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
+ else:
+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks + entry_index * num_blocks + block_id)
+ else: + else:
+ num_blocks = self.dst_num_blocks[engine_id] + num_blocks = self.dst_num_blocks[engine_id]
+ for layer_id in layer_ids: + for layer_id in layer_ids:
+ for is_value in [0, 1]: + for entry_index in range(self.num_cache_entries):
+ for block_id in block_ids: + for block_id in block_ids:
+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) + descs_ids.append(layer_id * self.num_cache_entries * num_blocks + entry_index * num_blocks + block_id)
+ return descs_ids + return descs_ids
+ +
+ def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False): + def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
...@@ -1243,10 +1304,15 @@ index 000000000..136a0bd37 ...@@ -1243,10 +1304,15 @@ index 000000000..136a0bd37
+ +
+ start_time = time.perf_counter() + start_time = time.perf_counter()
+ +
+ local_ranges = self._get_ranges(local_block_ids) + if self._is_mla:
+ staging_ranges = self._get_ranges(staging_block_ids) + # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all
+ staging_rearranging_ranges = None
+ staging_block_ids = local_block_ids
+ else:
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ +
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+ +
+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] + tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
...@@ -1282,11 +1348,12 @@ index 000000000..136a0bd37 ...@@ -1282,11 +1348,12 @@ index 000000000..136a0bd37
+ +
+ rearrange_start_time = time.perf_counter() + rearrange_start_time = time.perf_counter()
+ +
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + if not self._is_mla:
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ for kv_cache in self.kv_caches: + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for cache in kv_cache: + for kv_cache in self.kv_caches:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read") + for cache in kv_cache:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
+ +
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000) + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
+ logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
...@@ -1311,16 +1378,21 @@ index 000000000..136a0bd37 ...@@ -1311,16 +1378,21 @@ index 000000000..136a0bd37
+ +
+ start_time = time.perf_counter() + start_time = time.perf_counter()
+ +
+ local_ranges = self._get_ranges(local_block_ids) + if self._is_mla:
+ staging_ranges = self._get_ranges(staging_block_ids) + # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all
+ staging_rearranging_ranges = None
+ staging_block_ids = local_block_ids
+ else:
+ local_ranges = self._get_ranges(local_block_ids)
+ staging_ranges = self._get_ranges(staging_block_ids)
+ +
+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) + local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
+ +
+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): + for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) + logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
+ for kv_cache in self.kv_caches: + for kv_cache in self.kv_caches:
+ for cache in kv_cache: + for cache in kv_cache:
+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write") + rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
+ +
+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) + logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
+ +
...@@ -1330,6 +1402,7 @@ index 000000000..136a0bd37 ...@@ -1330,6 +1402,7 @@ index 000000000..136a0bd37
+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) + remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] + local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
+ +
+ logger.debug("Creating xfer handles")
+ for i in range(tp_multiplier): + for i in range(tp_multiplier):
+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) + staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids) + assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
...@@ -1364,7 +1437,10 @@ index 000000000..136a0bd37 ...@@ -1364,7 +1437,10 @@ index 000000000..136a0bd37
+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}" + assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
+ +
+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier) + logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
+ dst_block_len = self.block_len // tp_multiplier + if self._is_mla:
+ dst_block_len = self.block_len
+ else:
+ dst_block_len = self.block_len // tp_multiplier
+ if tp_multiplier not in self.src_xfer_side_handles: + if tp_multiplier not in self.src_xfer_side_handles:
+ # create descs and xfer side handles + # create descs and xfer side handles
+ blocks_data = [] + blocks_data = []
...@@ -1372,7 +1448,7 @@ index 000000000..136a0bd37 ...@@ -1372,7 +1448,7 @@ index 000000000..136a0bd37
+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]: + for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
+ for block_id in range(self.num_blocks): + for block_id in range(self.num_blocks):
+ block_offset = block_id * self.block_len + block_offset = block_id * self.block_len
+ for i in range(tp_multiplier): + for i in range(1 if self._is_mla else tp_multiplier):
+ tp_multiplier_offset = i * dst_block_len + tp_multiplier_offset = i * dst_block_len
+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) + blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) + logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment