# test_cpu_gpu.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import time

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
12
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
13

# Attention backends whose KV-cache layouts the offloading handlers are tested
# against. Layers in `test_transfer` cycle through this list in order.
BACKENDS_TO_TEST = [FlashAttentionBackend]

# FlashInfer and the FlashAttention MLA backend are only added (and imported)
# on non-ROCm platforms -- presumably because they are unavailable there.
if not current_platform.is_rocm():
    from vllm.v1.attention.backends.flashinfer import FlashInferBackend

    BACKENDS_TO_TEST.append(FlashInferBackend)

    from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend

    BACKENDS_TO_TEST.append(FlashAttnMLABackend)
# Parametrization grids for `test_transfer`; each list is one
# `pytest.mark.parametrize` axis.
NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
# 1 -> CPU and GPU blocks are the same size; 3 -> one CPU block spans three
# GPU blocks (cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size).
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]
SEEDS = [0]
CUDA_DEVICES = ["cuda:0"]
# Number of CPU blocks transferred in a single request.
NUM_MAPPINGS = [3]


@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
    gpu_to_cpu: bool,
    num_mappings: int,
    head_size: int,
    num_heads: int,
    gpu_block_size: int,
    gpu_blocks_per_cpu_block: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Check one asynchronous CPU<->GPU KV-cache block transfer end to end.

    Builds a random GPU KV cache per layer (cycling through
    ``BACKENDS_TO_TEST``) and creates ``CpuGpuOffloadingHandlers`` for them.
    It then transfers ``num_mappings`` randomly chosen CPU blocks to or from
    randomly chosen GPU blocks, in the direction selected by ``gpu_to_cpu``.
    Finally it verifies that:

    * every source tensor is unchanged, and
    * every destination block holds the data of the source block mapped onto
      it, or its original contents if no source block mapped onto it.

    One CPU block spans ``gpu_blocks_per_cpu_block`` GPU-sized blocks. For
    the CPU->GPU direction, the first ``gpu_blocks_per_cpu_block - 1``
    GPU-sized sub-blocks are dropped, so the load starts in the middle of a
    CPU block.
    """
    current_platform.seed_everything(seed)

    # Create per-layer GPU KV caches, cycling through the available backends.
    attn_backends_list = BACKENDS_TO_TEST
    gpu_caches = {}
    attn_backends = {}
    for i in range(num_layers):
        layer_name = f"layer {i}"
        attn_backend = attn_backends_list[i % len(attn_backends_list)]
        attn_backends[layer_name] = attn_backend

        gpu_cache_shape = attn_backend.get_kv_cache_shape(
            num_gpu_blocks, gpu_block_size, num_heads, head_size
        )
        gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)

    # Create the offloading handlers (one per transfer direction).
    cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
    handlers = CpuGpuOffloadingHandlers(
        attn_backends=attn_backends,
        gpu_block_size=gpu_block_size,
        cpu_block_size=cpu_block_size,
        num_cpu_blocks=num_cpu_blocks,
        gpu_caches=gpu_caches,
    )

    # Select block mappings: enough GPU blocks to fill `num_mappings` CPU blocks.
    gpu_blocks = random.sample(
        range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block
    )
    cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

    # Express each CPU block as the consecutive run of GPU-sized block ids it
    # covers, so CPU and GPU blocks can be paired one-to-one.
    cpu_blocks_in_gpu_block_size = [
        cpu_block * gpu_blocks_per_cpu_block + i
        for cpu_block in cpu_blocks
        for i in range(gpu_blocks_per_cpu_block)
    ]

    # For loads (CPU->GPU), drop the first `gpu_blocks_per_cpu_block - 1`
    # GPU-sized blocks to test reading from the middle of a CPU block
    # (a no-op when CPU and GPU block sizes are equal).
    if not gpu_to_cpu:
        skip = gpu_blocks_per_cpu_block - 1
        gpu_blocks = gpu_blocks[skip:]
        cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[skip:]

    # Set the transfer direction.
    if gpu_to_cpu:
        handler = handlers.gpu_to_cpu_handler
        src_spec_class = GPULoadStoreSpec
        dst_spec_class = CPULoadStoreSpec
        src_blocks = gpu_blocks
        dst_blocks = cpu_blocks
        src_blocks_in_gpu_block_size = gpu_blocks
        dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
    else:
        handler = handlers.cpu_to_gpu_handler
        src_spec_class = CPULoadStoreSpec
        dst_spec_class = GPULoadStoreSpec
        src_blocks = cpu_blocks
        dst_blocks = gpu_blocks
        src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_blocks_in_gpu_block_size = gpu_blocks
        dst_size_in_gpu_blocks = num_gpu_blocks

    # Build the expected dst -> src mapping (in GPU-block-sized units).
    dst_to_src = dict(zip(dst_blocks_in_gpu_block_size, src_blocks_in_gpu_block_size))

    # Build transfer specs.
    src_spec = src_spec_class(src_blocks)
    dst_spec = dst_spec_class(dst_blocks)

    # Snapshot src and dst tensors before the transfer.
    orig_src_caches = [x.clone() for x in handler.src_tensors]
    orig_dst_caches = [x.clone() for x in handler.dst_tensors]

    # Submit the transfer with job id 1.
    assert handler.transfer_async(1, (src_spec, dst_spec))
    assert {x[0] for x in handler._transfers} == {1}

    # Wait for the transfer to complete. A monotonic clock keeps the deadline
    # immune to wall-clock adjustments, and the `else` clause fails loudly on
    # timeout instead of silently falling through to the checks below.
    deadline = time.monotonic() + 10
    while time.monotonic() < deadline:
        finished = handler.get_finished()
        if finished:
            assert finished == [(1, True)]
            break
        time.sleep(0.1)
    else:
        pytest.fail("transfer 1 did not finish within 10 seconds")

    # Verify that the src tensors did not change.
    for orig_tensor, tensor in zip(orig_src_caches, handler.src_tensors):
        assert torch.equal(orig_tensor, tensor)

    # Verify every dst block: mapped blocks hold the src data, the rest are
    # untouched.
    for dst_block in range(dst_size_in_gpu_blocks):
        src_block_candidate = dst_to_src.get(dst_block)
        for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
            handler.src_tensors,
            handler.dst_tensors,
            orig_dst_caches,
            handler.kv_dim_before_num_blocks,
        ):
            if kv_dim:
                # Tensor has a leading dim of size 2 before the block dim;
                # check both halves (key, value).
                for i in range(2):
                    if src_block_candidate is not None:
                        expected_value = src_cache[i][src_block_candidate]
                    else:
                        expected_value = orig_dst_cache[i][dst_block]
                    torch.testing.assert_close(
                        dst_cache[i][dst_block].cpu(), expected_value.cpu()
                    )
            else:
                if src_block_candidate is not None:
                    expected_value = src_cache[src_block_candidate]
                else:
                    expected_value = orig_dst_cache[dst_block]
                torch.testing.assert_close(
                    dst_cache[dst_block].cpu(), expected_value.cpu()
                )