tp_utils.py 1.26 KB
Newer Older
chenzk's avatar
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""

from __future__ import annotations

import torch.distributed as dist


def tensor_parallel_rank_for_sharding() -> int:
    """Return this process's rank to use when sharding across tensor-parallel ranks.

    The vLLM helper ``get_tensor_model_parallel_rank`` is tried first. If
    importing or calling it raises any exception, the rank of the default
    ``torch.distributed`` process group is used instead, and ``0`` is
    returned when no process group has been initialized at all.

    Returns:
        The rank as a plain ``int``.
    """
    try:
        # Imported lazily so this module does not require vLLM at import time.
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_rank as _vllm_tp_rank,
        )

        tp_rank = int(_vllm_tp_rank())
    except Exception:
        # Deliberately broad: any failure of the vLLM path (import or call)
        # falls back to the default process group, or to rank 0 without one.
        tp_rank = int(dist.get_rank()) if dist.is_initialized() else 0
    return tp_rank


def tensor_parallel_world_size_for_sharding() -> int:
    """Return the number of ranks to shard across.

    The vLLM helper ``get_tensor_model_parallel_world_size`` is tried first.
    If importing or calling it raises any exception, the world size of the
    default ``torch.distributed`` process group is used instead, and ``1``
    is returned when no process group has been initialized at all.

    Returns:
        The world size as a plain ``int``.
    """
    try:
        # Imported lazily so this module does not require vLLM at import time.
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_world_size as _vllm_tp_world_size,
        )

        tp_size = int(_vllm_tp_world_size())
    except Exception:
        # Deliberately broad: any failure of the vLLM path (import or call)
        # falls back to the default process group, or to size 1 without one.
        tp_size = int(dist.get_world_size()) if dist.is_initialized() else 1
    return tp_size


def kv_heads_shard_divisor() -> int:
    """Return the divisor used when sharding KV heads across ranks.

    This is exactly the value reported by
    :func:`tensor_parallel_world_size_for_sharding`; the function exists as
    a purpose-named alias for KV-head sharding call sites.
    """
    divisor = tensor_parallel_world_size_for_sharding()
    return divisor