diff --git a/vllm/config.py b/vllm/config.py index 9ba49757..7e871521 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2629,7 +2629,7 @@ class KVTransferConfig(BaseModel): kv_buffer_size: float = 1e9 # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. + # are 'kv_producer', 'kv_consumer', and 'kv_both'. kv_role: Optional[str] = None # The rank of this vLLM instance in the KV cache transfer. Typical value: @@ -2647,6 +2647,14 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + + # This does not need to be set by the user. It is set by the connector. + kv_producers_parallel_size: Optional[int] = None + kv_producers_tensor_parallel_size: Optional[int] = None + kv_producers_pipeline_parallel_size: Optional[int] = None + kv_consumers_tensor_parallel_size: Optional[int] = None + kv_consumers_pipeline_parallel_size: Optional[int] = None + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2685,6 +2693,7 @@ class KVTransferConfig(BaseModel): "is set, supported roles are `kv_producer`, " "`kv_consumer`, and `kv_both`") + @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ @@ -2706,6 +2715,18 @@ class KVTransferConfig(BaseModel): return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] + @property + def tensor_parallel_multiplier(self) -> int: + return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size + + @property + def kv_consumers_parallel_size(self) -> int: + return self.kv_parallel_size - self.kv_producers_parallel_size + + @property + def kv_world_size(self) -> int: + return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier + class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index fe480533..b768e03c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -27,13 +27,13 @@ class KVConnectorFactory: @classmethod def create_connector(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: + config: "VllmConfig", world_group) -> KVConnectorBase: connector_name = config.kv_transfer_config.kv_connector if connector_name not in cls._registry: raise ValueError(f"Unsupported connector type: {connector_name}") connector_cls = cls._registry[connector_name]() - return connector_cls(rank, local_rank, config) + return connector_cls(rank, local_rank, config, world_group) # Register various connectors here. diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py index 2033e976..71cd0567 100644 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -8,13 +8,15 @@ MooncakePipe. But the logic can be extended to support other pipe and lookup buffer. """ +import re from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch from vllm import _custom_ops as ops -from vllm.config import VllmConfig +from vllm.config import VllmConfig, KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( SimpleBuffer) from vllm.logger import init_logger @@ -33,10 +35,10 @@ class SimpleConnector(KVConnectorBase): rank: int, local_rank: int, config: VllmConfig, + world_group, ): self.config = config.kv_transfer_config - self.tp_size = config.parallel_config.tensor_parallel_size if self.config.kv_connector == "PyNcclConnector": from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( @@ -71,20 +73,31 @@ class SimpleConnector(KVConnectorBase): self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] + self._broadcast_and_enhance_kv_config(rank, config, world_group) + + self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) + self.tp_size = config.parallel_config.tensor_parallel_size + # 2 pipes for every rank in the world - port_offset_base = 2 * rank + if self.config.is_kv_producer: + port_offset_base = 2 * rank + 1 + else: + port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 + self.local_kv_rank = rank % self.config.tensor_parallel_multiplier # In disaggregated prefill, the prefill vLLM only uses send pipe # and the decode vLLM only uses recv pipe if self.config.is_kv_producer: if self.config.kv_connector == "PyNcclConnector": self.producer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.producer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -108,11 +121,13 @@ class SimpleConnector(KVConnectorBase): # its recv pipe to the send pipe of KV producder if self.config.kv_connector == "PyNcclConnector": self.consumer_data_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base, ) self.consumer_signal_pipe = PyNcclPipe( + kv_group_rank=self.kv_group_rank, local_rank=local_rank, config=self.config, port_offset=port_offset_base + 1, @@ -131,21 +146,25 @@ class SimpleConnector(KVConnectorBase): self.config.kv_buffer_size, ) - def select(self, input_tokens: Optional[torch.Tensor], + def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) + assert self.consumer_buffer is not None, "Please initialize the "\ "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) + return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: + logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) + assert self.producer_buffer is not None, "Please initialize the "\ "producer buffer before calling insert." - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) def send_kv_caches_and_hidden_states( self, @@ -161,6 +180,7 @@ class SimpleConnector(KVConnectorBase): slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() start_layer = model_executable.model.start_layer end_layer = model_executable.model.end_layer + request_ids = list(model_input.request_ids_to_seq_ids.keys()) model_config = model_executable.model.config num_heads = int(model_config.num_key_value_heads / self.tp_size) @@ -175,27 +195,36 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + _, decode_kv_rank = self.parse_request_id(current_request_id) + starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) + + for target_rank in range(self.config.tensor_parallel_multiplier): + + keys, values = [], [] - keys, values = [], [] + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier + head_start = target_rank * num_heads_per_rank + head_end = head_start + num_heads_per_rank - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) + keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) + values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) + self.insert(starting_kv_group_rank, target_rank, current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) @@ -215,6 +244,7 @@ class SimpleConnector(KVConnectorBase): input_tokens_tensor = model_input.input_tokens seq_lens = model_input.attn_metadata.seq_lens slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + request_ids = list(model_input.request_ids_to_seq_ids.keys()) hidden_or_intermediate_states_for_one_req = [] @@ -229,13 +259,15 @@ class SimpleConnector(KVConnectorBase): start_pos = sum(seq_lens[:idx]) end_pos = start_pos + slen current_tokens = input_tokens_tensor[start_pos:end_pos] + current_request_id = request_ids[idx] + prefill_rank, _ = self.parse_request_id(current_request_id) num_tokens = slen # collecting data for rebuilding the input input_tokens_list.append(current_tokens) start_pos_list.append(start_pos) - ret = self.select(current_tokens, + ret = self.select(prefill_rank, current_tokens, torch.ones_like(current_tokens, dtype=bool)) if ret[0] is None: # didn't find any match. @@ -312,3 +344,77 @@ class SimpleConnector(KVConnectorBase): # MooncakePipe reuses data_pipe for signal_pipe, so we only have to # close the data_pipe. pass + + @staticmethod + def parse_request_id(request_id): + # Regular expression to match the ranks + pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" + + # Use re.search to find the pattern in the request_id + match = re.search(pattern, request_id) + + if match: + # Extract the ranks + prefill_rank = int(match.group(1)) + decode_rank = int(match.group(2)) + + return prefill_rank, decode_rank + else: + return None, None + + + + def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: + if kv_rank < config.kv_producers_parallel_size: + return kv_rank + + kv_consumer_rank = kv_rank - config.kv_producers_parallel_size + return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier + + def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): + if rank == 0: + if self.config.kv_connector == "PyNcclConnector": + config_group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port, + rank=self.config.kv_rank, + world_size=self.config.kv_parallel_size, + ) + parallel_configs = config_group.all_gather_obj({ + "kv_role": self.config.kv_role, + "tensor_parallel_size": config.parallel_config.tensor_parallel_size, + "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, + }) + logger.debug("parallel_configs: %s", parallel_configs) + kv_config_enhanced = { + "kv_producers_tensor_parallel_size": None, + "kv_consumers_tensor_parallel_size": None, + "kv_producers_pipeline_parallel_size": None, + "kv_consumers_pipeline_parallel_size": None, + "kv_producers_parallel_size": 0, + } + for parallel_config in parallel_configs: + kv_role = parallel_config["kv_role"] + assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" + + if kv_role == "kv_producer": + kv_config_enhanced["kv_producers_parallel_size"] += 1 + if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: + kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] + kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] + else: + assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" + assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" + world_group.broadcast_object(kv_config_enhanced) + + else: + raise NotImplementedError("MooncakeConnector is not supported in Triton Distributed vllm patch") + else: + kv_config_enhanced = world_group.broadcast_object() + logger.info("kv_config_enhanced: %s", kv_config_enhanced) + + self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] + self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] + self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] + self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] + self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 5e1b6235..b4506877 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -12,7 +12,8 @@ import threading import time from collections import deque -from typing import Deque, List, Optional, Union +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, List, Optional, Union, Dict import torch @@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase): self.buffer_lock = threading.Lock() self.signal_pipe = signal_pipe self.data_pipe = data_pipe - self.request_handling_thread: Optional[threading.Thread] = None + self.request_handling_thread: Optional[ThreadPoolExecutor] = None self.normal_signal = torch.tensor([0], device="cpu") self.end_signal = None @@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase): # tokens_roi_sender: tokens and roi of the producer (in the buffer) # tokens_roi_recver: tokens and roi of the consumer (query) - tokens_sender = tokens_roi_sender[0] - tokens_recver = tokens_roi_recver[0] - roi_sender = tokens_roi_sender[1] - roi_recver = tokens_roi_recver[1] + target_rank_sender = tokens_roi_sender[0] + target_rank_recver = tokens_roi_recver[0] + + if target_rank_sender.item() != target_rank_recver.item(): + return 0 + + tokens_sender = tokens_roi_sender[1] + tokens_recver = tokens_roi_recver[1] + roi_sender = tokens_roi_sender[2] + roi_recver = tokens_roi_recver[2] if tokens_recver is None: # consumer sends an empty request @@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase): return 0 - def _send_tensor_and_dec_size(self, - tensor: Optional[torch.Tensor]) -> None: + def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], + target_rank: int) -> None: assert tensor is not None, "Use self.data_pipe.send(None) instead" self.buffer_size -= tensor.element_size() * tensor.numel() if tensor.dtype == torch.bool: tensor = tensor.float() - self.data_pipe.send_tensor(tensor) + self.data_pipe.send_tensor(tensor, target_rank) def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): @@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase): raise AssertionError(f"Unknown data type {type(data)}") - def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor): @@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(hidden, torch.Tensor): hidden = hidden.clone() - buffer_item = [input_tokens, roi, key, value, hidden] + buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden] with self.buffer_lock: for data in buffer_item: @@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase): def _is_end_signal(self, signal): return signal is None - def drop_select_handler(self): + def drop_select_handler(self, rank: int): try: - while True: - signal = self.signal_pipe.recv_tensor() - if self._is_end_signal(signal): - logger.info("Received end signal!") - break - - input_tokens = self.data_pipe.recv_tensor() - - roi = self.data_pipe.recv_tensor() - assert roi is not None, "Please provide the roi when sending "\ - "drop-select request" - roi = (roi > 0.5) - tokens_roi_recver = [input_tokens, roi] - - matched_length = 0 - - # perform input tokens and roi matching - # FIXME: this matching is O(n), ideally it should be O(1) - # but this buffer size won't (and shouldn't) be too large so - # the fix is not urgent. - with self.buffer_lock: - - for _ in range(len(self.buffer)): - - temp_length = self._matches(self.buffer[0], - tokens_roi_recver) - if temp_length > 0: - matched_length = temp_length - break - # rotate the element we just accessed to the end - self.buffer.rotate(-1) - - if matched_length > 0: - # need to clone the tensor - # in case the tensor is freed before sending finishes - matched_item = self.buffer.popleft() - for tensor in matched_item: - self._send_tensor_and_dec_size(tensor) - - else: - # no match, just send None - for _ in range(5): - self.data_pipe.send_tensor(None) + signal = self.signal_pipe.recv_tensor(rank) + if self._is_end_signal(signal): + logger.info("Received end signal!") + return + target_kv_rank = self.data_pipe.recv_tensor(rank) + # assert target_rank.item() == rank, "Target rank does not match"\ + # "the rank of the drop-select handler" + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [target_kv_rank, input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + target_rank = matched_item[0].item() + for tensor in matched_item[1:]: + self._send_tensor_and_dec_size(tensor, rank) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None, rank) except RuntimeError as e: if 'Connection closed by peer' not in str(e): @@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase): logger.debug("Closing drop_select_handler") def drop_select( - self, input_tokens: Optional[torch.Tensor], + self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - assert self.request_handling_thread is None, \ + assert not self.request_handling_thread, \ "drop_select should be called by the KV cache consumer "\ "(e.g. the decode vLLM instance)" @@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase): if isinstance(roi, torch.Tensor): roi = roi.clone().float() - self.signal_pipe.send_tensor(self.normal_signal) - self.data_pipe.send_tensor(input_tokens) - self.data_pipe.send_tensor(roi) + self.signal_pipe.send_tensor(self.normal_signal, rank) + + self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) + self.data_pipe.send_tensor(input_tokens, rank) + self.data_pipe.send_tensor(roi, rank) - input_tokens = self.data_pipe.recv_tensor() - roi = self.data_pipe.recv_tensor() + input_tokens = self.data_pipe.recv_tensor(rank) + roi = self.data_pipe.recv_tensor(rank) if roi is not None: # convert from float tensor to bool tensor # as PyNccl does not support sending bool tensor roi = (roi > 0.5) - key = self.data_pipe.recv_tensor() - value = self.data_pipe.recv_tensor() - hidden = self.data_pipe.recv_tensor() + key = self.data_pipe.recv_tensor(rank) + value = self.data_pipe.recv_tensor(rank) + hidden = self.data_pipe.recv_tensor(rank) return [input_tokens, roi, key, value, hidden] def full_handler(self): time.sleep(0.001) - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, key: torch.Tensor, value: torch.Tensor, hidden: torch.Tensor) -> None: @@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase): while self.buffer_size > self.buffer_size_threshold: self.full_handler() - self._add_to_buffer(input_tokens, roi, key, value, hidden) + self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) # when calling the insert, the current process is a sender # need to launch the request handler and start listening to request. + target_rank_global = target_rank + kv_group_rank if self.request_handling_thread is None: - self.request_handling_thread = threading.Thread( - target=self.drop_select_handler) - self.request_handling_thread.start() + self.request_handling_thread = ThreadPoolExecutor(max_workers=1) + self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) def close(self): - if hasattr(self, "request_handling_thread" - ) and self.request_handling_thread is not None: - self.request_handling_thread.join() + if hasattr(self, "request_handling_thread") and self.request_handling_thread: + self.request_handling_thread.shutdown() else: # TODO: have a explicit close signal and have a explicit way to diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py index 40589fb3..da2829cf 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/base.py +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -23,7 +23,7 @@ class KVPipeBase(ABC): """ @abstractmethod - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: """Send a tensor, or None, via the pipe. Need to support sending None -- important for error handling. @@ -41,7 +41,7 @@ class KVPipeBase(ABC): raise NotImplementedError @abstractmethod - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """Receive a tensor (can be None) from the pipeline. Returns: diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 7aa53d07..db10f8a0 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase): METADATA_DTYPE = torch.int64 def __init__(self, + kv_group_rank: int, local_rank: int, config: KVTransferConfig, device: Optional[str] = None, port_offset: int = 0): self.config = config self.local_rank = local_rank - self.kv_rank = self.config.kv_rank + self.kv_group_rank = kv_group_rank self.kv_parallel_size = self.config.kv_parallel_size + self.kv_world_size = self.config.kv_world_size if device is None: self.device = self._select_device(self.config.kv_buffer_device) else: self.device = self._select_device(device) # build distributed connection and send/recv implementation + logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset) self.group = StatelessProcessGroup.create( host=self.config.kv_ip, port=self.config.kv_port + port_offset, - rank=self.kv_rank, - world_size=self.kv_parallel_size, + rank=self.kv_group_rank, + world_size=self.kv_world_size, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() impl = self._get_device_send_recv_impl(self.group) self.device_send_func, self.device_recv_func = impl - # set target rank - self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size - self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size # transportation-related variables self.transport_thread: Optional[ThreadPoolExecutor] = None @@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase): dtype=metadata["dtype"], device=self.device) - def _send_metadata(self, metadata: Metadata): + def _send_metadata(self, metadata: Metadata, target_rank: int): """ Send the metadata dictionary to the target rank. Parameters: - metadata: A dictionary with keys "dtype" and "shape". """ - self.group.send_obj(metadata, self.target_rank_for_send) + self.group.send_obj(metadata, target_rank) - def _recv_metadata(self) -> Metadata: + def _recv_metadata(self, src_rank: int) -> Metadata: """ Receive the metadata dictionary from the target rank. @@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase): - metadata: A dictionary with keys "dtype" and "shape" describing the tensor. """ - return self.group.recv_obj(self.target_rank_for_recv) + return self.group.recv_obj(src_rank) - def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ The actual implementation of sending the tensor and its metadata to the target rank. @@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase): being sent. """ metadata = self._make_metadata(tensor) - self._send_metadata(metadata) + self._send_metadata(metadata, target_rank) if tensor is not None: self.device_send_func(tensor.to(self.device), - self.target_rank_for_send) + target_rank) - def _recv_impl(self) -> Optional[torch.Tensor]: + def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: """ The actual implementation of receiving a tensor and its metadata from the target rank. @@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase): Returns: - buffer: The received tensor, or None if no tensor is received. """ - metadata = self._recv_metadata() + metadata = self._recv_metadata(src_rank) if metadata["dtype"] is None: return None buffer = self._prepare_recv_buffer(metadata) - self.device_recv_func(buffer, self.target_rank_for_recv) + self.device_recv_func(buffer, src_rank) return buffer def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], - tensor_size: int) -> None: + tensor_size: int, + target_rank: int) -> None: """ Wrapper for _send_impl to handle exceptions and update buffer size. """ try: - self._send_impl(tensor) + self._send_impl(tensor, target_rank) with self.buffer_size_lock: self.buffer_size -= tensor_size @@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase): logger.debug("KV cache transfer pipe is full. Waiting...") time.sleep(0.05) - def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: """ Sends a tensor and its metadata to the destination rank in a non-blocking way. @@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase): Parameters: - tensor: The tensor to send, or None if no tensor is being sent. """ + logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) @@ -242,19 +244,23 @@ class PyNcclPipe(KVPipeBase): self.buffer_size += tensor_size self.transport_thread.submit(self.send_tensor_wrapper, tensor, - tensor_size) + tensor_size, + target_rank) - def recv_tensor(self) -> Optional[torch.Tensor]: + def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: """ Receives a tensor and its metadata from the source rank. Blocking call. Returns: - tensor: The received tensor, or None if no tensor is received. """ + + logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) + if self.transport_thread is None: self.transport_thread = ThreadPoolExecutor(max_workers=1) - future = self.transport_thread.submit(self._recv_impl) + future = self.transport_thread.submit(self._recv_impl, src_rank) try: tensor = future.result() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py index 1e80e0bd..cd90206f 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_agent.py +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -35,6 +35,7 @@ class KVTransferAgent: rank: int, local_rank: int, config: "VllmConfig", + world_group, ): self.config = config @@ -47,7 +48,7 @@ class KVTransferAgent: "TransferAgent should only be used when kv_connector is set." self.connector = KVConnectorFactory.create_connector( - rank, local_rank, config) + rank, local_rank, config, world_group) def send_kv_caches_and_hidden_states( self, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 321902d1..b8937ef8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: _KV_TRANSFER = kv_transfer.KVTransferAgent( rank=get_world_group().rank, local_rank=get_world_group().local_rank, - config=vllm_config) + config=vllm_config, + world_group=get_world_group()) def ensure_model_parallel_initialized(