"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker.""" from __future__ import annotations import torch.distributed as dist def tensor_parallel_rank_for_sharding() -> int: """Rank within the tensor-parallel group (matches vLLM weight shards when embedded). Falls back to :func:`torch.distributed.get_rank` when vLLM parallel state is unavailable (standalone kvprune with only the default process group). """ try: from vllm.distributed.parallel_state import get_tensor_model_parallel_rank return int(get_tensor_model_parallel_rank()) except Exception: if dist.is_initialized(): return int(dist.get_rank()) return 0 def tensor_parallel_world_size_for_sharding() -> int: """World size of the tensor-parallel group.""" try: from vllm.distributed.parallel_state import ( get_tensor_model_parallel_world_size, ) return int(get_tensor_model_parallel_world_size()) except Exception: if dist.is_initialized(): return int(dist.get_world_size()) return 1 def kv_heads_shard_divisor() -> int: """Return world size used to shard KV heads (TP group when vLLM is loaded).""" return tensor_parallel_world_size_for_sharding()