# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import pytest
import torch

from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers

# Attention backends whose KV-cache shapes are exercised by test_transfer.
# Layers are assigned to these backends round-robin.
BACKENDS_TO_TEST = [FlashAttentionBackend]

if not current_platform.is_rocm():
    # These backends are only imported and tested on non-ROCm platforms.
    from vllm.v1.attention.backends.flashinfer import FlashInferBackend

    BACKENDS_TO_TEST.append(FlashInferBackend)

    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend

    BACKENDS_TO_TEST.append(FlashAttnMLABackend)

# Parametrization values for test_transfer.
NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]

# A logical (GPU) block is a whole number of kernel blocks; a CPU block is a
# whole number of logical blocks.
KERNEL_BLOCK_SIZES = [16]
LOGICAL_BLOCK_SIZES = [16, 32]
LOGICAL_BLOCKS_PER_CPU_BLOCK = [1, 3]

HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]
SEEDS = [0]

CUDA_DEVICES = ["cuda:0"]

# Number of CPU blocks involved in a single transfer job.
NUM_MAPPINGS = [3]


def _expand_to_kernel_blocks(block_ids, kernel_blocks_per_block):
    """Expand block ids into the ids of the kernel blocks they cover.

    Block ``b`` covers kernel blocks ``b * k`` through ``b * k + k - 1``,
    where ``k`` is ``kernel_blocks_per_block``. The output preserves the
    order of ``block_ids``.
    """
    return [
        block_id * kernel_blocks_per_block + offset
        for block_id in block_ids
        for offset in range(kernel_blocks_per_block)
    ]


@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("kernel_block_size", KERNEL_BLOCK_SIZES)
@pytest.mark.parametrize("logical_block_size", LOGICAL_BLOCK_SIZES)
@pytest.mark.parametrize("logical_blocks_per_cpu_block", LOGICAL_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
    default_vllm_config,
    gpu_to_cpu: bool,
    num_mappings: int,
    head_size: int,
    num_heads: int,
    kernel_block_size: int,
    logical_block_size: int,
    logical_blocks_per_cpu_block: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Run one CPU<->GPU offload transfer job and verify the copied data.

    The test builds per-layer GPU KV caches, cycling through
    BACKENDS_TO_TEST, and submits a single transfer job in the requested
    direction. After the job finishes it checks three things:

    - the source tensors are untouched;
    - every mapped destination kernel block equals its source block;
    - every unmapped destination block keeps its original contents.
    """
    set_random_seed(seed)

    # Create per-layer GPU KV caches, assigning backends round-robin.
    attn_backends_list = BACKENDS_TO_TEST

    assert logical_block_size % kernel_block_size == 0
    kernel_blocks_per_gpu_block = logical_block_size // kernel_block_size
    num_gpu_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block

    gpu_caches = {}
    attn_backends = {}
    for i in range(num_layers):
        layer_name = f"layer {i}"

        attn_backend = attn_backends_list[i % len(attn_backends_list)]
        attn_backends[layer_name] = attn_backend

        gpu_cache_shape = attn_backend.get_kv_cache_shape(
            num_gpu_kernel_blocks, kernel_block_size, num_heads, head_size
        )
        gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)

    # Create the handlers; one CPU block spans several logical (GPU) blocks.
    cpu_block_size = logical_blocks_per_cpu_block * logical_block_size
    kernel_blocks_per_cpu_block = cpu_block_size // kernel_block_size
    handlers = CpuGpuOffloadingHandlers(
        attn_backends=attn_backends,
        gpu_block_size=logical_block_size,
        cpu_block_size=cpu_block_size,
        num_cpu_blocks=num_cpu_blocks,
        gpu_caches=gpu_caches,
    )

    # Select block mappings: each CPU block pairs with
    # `logical_blocks_per_cpu_block` consecutive entries of `gpu_blocks`.
    gpu_blocks = random.sample(
        range(num_gpu_blocks), num_mappings * logical_blocks_per_cpu_block
    )
    cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

    # Convert GPU and CPU block ids to kernel-block ids.
    gpu_blocks_in_kernel_block_size = _expand_to_kernel_blocks(
        gpu_blocks, kernel_blocks_per_gpu_block
    )
    cpu_blocks_in_kernel_block_size = _expand_to_kernel_blocks(
        cpu_blocks, kernel_blocks_per_cpu_block
    )

    # When loading (CPU -> GPU), skip some leading GPU blocks so the transfer
    # starts reading from the middle of the first CPU block.
    if not gpu_to_cpu:
        gpu_blocks_to_skip = logical_blocks_per_cpu_block - 1
        gpu_blocks = gpu_blocks[gpu_blocks_to_skip:]
        kernel_blocks_to_skip = gpu_blocks_to_skip * kernel_blocks_per_gpu_block
        gpu_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size[
            kernel_blocks_to_skip:
        ]
        cpu_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size[
            kernel_blocks_to_skip:
        ]

    # Set the transfer direction.
    if gpu_to_cpu:
        handler = handlers.gpu_to_cpu_handler
        src_blocks = gpu_blocks
        dst_blocks = cpu_blocks
        src_spec = GPULoadStoreSpec(src_blocks, group_sizes=(len(src_blocks),))
        dst_spec = CPULoadStoreSpec(dst_blocks)
        src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
        dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
        dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block
        expected_transfer_type = ("GPU", "CPU")
    else:
        handler = handlers.cpu_to_gpu_handler
        src_blocks = cpu_blocks
        dst_blocks = gpu_blocks
        src_spec = CPULoadStoreSpec(src_blocks)
        dst_spec = GPULoadStoreSpec(dst_blocks, group_sizes=(len(dst_blocks),))
        src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
        dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
        dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
        expected_transfer_type = ("CPU", "GPU")

    # Build the destination -> source kernel-block mapping.
    dst_to_src = dict(
        zip(dst_blocks_in_kernel_block_size, src_blocks_in_kernel_block_size)
    )

    # Snapshot src and dst tensors before the transfer.
    orig_src_caches = [x.clone() for x in handler.src_tensors]
    orig_dst_caches = [x.clone() for x in handler.dst_tensors]

    # Submit the transfer as job 1.
    start_time = time.perf_counter()
    assert handler.transfer_async(1, (src_spec, dst_spec))
    assert {x.job_id for x in handler._transfers} == {1}

    # Poll for completion. Fail explicitly on timeout instead of silently
    # falling through to the data checks.
    deadline = time.monotonic() + 10
    while time.monotonic() < deadline:
        finished = handler.get_finished()
        if finished:
            assert len(finished) == 1
            result = finished[0]
            assert result.job_id == 1
            assert result.success
            # Compare against a precomputed tuple. Writing
            # `x == A if cond else B` parses as `(x == A) if cond else B`,
            # which would assert a (truthy) tuple in the else case.
            assert result.transfer_type == expected_transfer_type
            assert (
                result.transfer_size
                == handler.total_block_size_in_bytes
                * handler.dst_block_size_factor
                * len(dst_blocks)
            )
            assert result.transfer_time > 0
            assert result.transfer_time < (time.perf_counter() - start_time)
            break
        time.sleep(0.1)
    else:
        pytest.fail("transfer did not complete within 10 seconds")

    # Verify the source tensors did not change.
    for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
        assert torch.equal(orig_tensor, tensor)

    # Verify the destination tensors: mapped kernel blocks must equal their
    # source block; every other block must keep its pre-transfer contents.
    for dst_block in range(dst_size_in_kernel_blocks):
        src_block_candidate = dst_to_src.get(dst_block)
        for src_cache, dst_cache, orig_dst_cache in zip(
            handler.src_tensors,
            handler.dst_tensors,
            orig_dst_caches,
        ):
            if src_block_candidate is not None:
                expected_value = src_cache[src_block_candidate]
            else:
                expected_value = orig_dst_cache[dst_block]
            torch.testing.assert_close(dst_cache[dst_block].cpu(), expected_value.cpu())