dp_utils.py 2.87 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.distributed as dist

from vllm.distributed.parallel_state import get_dp_group


9
10
11
12
13
14
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
    if dp_size == 1:
        return None
    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")


Woosuk Kwon's avatar
Woosuk Kwon committed
15
def get_batch_metadata_across_dp(
    num_tokens: int,
    cudagraph_size: int,
    cudagraph_runtime_mode: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Exchange this rank's batch metadata with all data-parallel ranks.

    Each rank writes its three values into its own column (``dp_rank``) of a
    zero-initialized ``(3, dp_size)`` int32 CPU tensor, and the tensor is then
    all-reduced over the DP CPU group.

    Args:
        num_tokens: Number of tokens on this rank.
        cudagraph_size: Cudagraph size for this rank, already encoded as an
            integer by the caller.
        cudagraph_runtime_mode: Cudagraph runtime mode of this rank as an int.
        dp_size: Number of data-parallel ranks; must be greater than 1.
        dp_rank: Index of this rank; selects the column that is written.

    Returns:
        A tuple ``(num_tokens_across_dp, cudagraph_size_across_dp,
        cudagraph_mode_across_dp)``: the three rows of the reduced tensor,
        each of length ``dp_size``. The rows are views into one tensor.
    """
    assert dp_size > 1, "get_batch_metadata_across_dp requires dp_size > 1"
    # Use CPU group to avoid CPU-GPU synchronization.
    group = get_dp_group().cpu_group
    tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
    # Only this rank's column is non-zero before the reduction.
    tensor[0, dp_rank] = num_tokens
    tensor[1, dp_rank] = cudagraph_size
    tensor[2, dp_rank] = cudagraph_runtime_mode
    # all_reduce uses its default reduction op (sum) and updates `tensor`
    # in place.
    dist.all_reduce(tensor, group=group)
    return tensor[0], tensor[1], tensor[2]
31
32


33
def get_cudagraph_and_dp_padding(
34
35
36
37
38
39
    num_tokens: int,
    cudagraph_size: int | None,
    cudagraph_runtime_mode: int,
    dp_size: int,
    dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
40
    if dp_size == 1:
41
        if cudagraph_size is not None:
42
            return cudagraph_size, None, cudagraph_runtime_mode
43
        else:
44
            return num_tokens, None, cudagraph_runtime_mode
45

46
    # Convert None to -1 for sync (indicates no cudagraph available)
47
48
49
50
    if num_tokens == 0:
        cudagraph_size = 0
    elif cudagraph_size is None:
        cudagraph_size = -1
51
52
53
54
55

    num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
        get_batch_metadata_across_dp(
            num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
        )
56
57
58
    )
    if torch.all(num_tokens_across_dp == 0).item():
        # All ranks have zero tokens to run.
59
60
61
62
63
64
        return 0, None, 0

    # Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
    synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
    # Check if all ranks have valid cudagraph_size.
    all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
65

66
67
    if synced_cudagraph_mode != 0 and all_have_cudagraph:
        # All ranks use cudagraph. Pad to max cudagraph_size.
68
69
        max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
        num_tokens_across_dp[:] = max_cudagraph_size
70
        return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
71
    else:
72
73
74
        # Fall back to eager mode (no cudagraph).
        # Either some rank doesn't have cudagraph size or mode is NONE.
        synced_cudagraph_mode = 0
75
76
        num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
        num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
77
        return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode