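"""Microbenchmark for HiCacheController page backup and prefetch transfer.

Times the zerocopy path (page_first host layout) against the generic path
(layer_first host layout) for storage backup (StorageOperation) and prefetch
(PrefetchOperation) of KV cache pages, using a single-rank gloo process group
and the hf3fs storage backend.
"""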
import threading
import time

import torch
from tqdm import tqdm

from sglang.srt.distributed import (
    get_world_group,
    init_distributed_environment,
    initialize_model_parallel,
)
from sglang.srt.managers.cache_controller import (
    HiCacheController,
    PrefetchOperation,
    StorageOperation,
)
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

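# Single-rank distributed setup (gloo backend); HiCacheController is given the
# CPU process group obtained below.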
init_distributed_environment(
    world_size=1,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:23456",
    local_rank=0,
    backend="gloo",
)

initialize_model_parallel(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
)

group = get_world_group().cpu_group

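# Benchmark configuration: KV cache geometry, HiCache layout/policy/backends,
# and the size and number of operations to time.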
max_total_num_tokens = 524288
page_size = 64
kv_cache_dtype = torch.bfloat16
layer_num = 64
head_num, head_dim = 8, 128
device = "cuda"
hicache_ratio = 2
hicache_size = 0
hicache_mem_layout = "page_first"
# hicache_mem_layout = "layer_first"
hicache_write_policy = "write_through"
hicache_io_backend = "kernel"
hicache_storage_backend = "hf3fs"
prefetch_threshold = 256

op_size = 1024
op_num = 16

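# Device-side KV cache pool (MHA layout) and its token-to-KV allocator.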
token_to_kv_pool = MHATokenToKVPool(
    max_total_num_tokens,
    page_size=page_size,
    dtype=kv_cache_dtype,
    head_num=head_num,
    head_dim=head_dim,
    layer_num=layer_num,
    device=device,
    enable_memory_saver=True,
)

token_to_kv_pool_allocator = TokenToKVPoolAllocator(
    max_total_num_tokens,
    dtype=kv_cache_dtype,
    device=device,
    kvcache=token_to_kv_pool,
    need_sort=False,
)

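# Host-side KV pool mirroring the device pool; hicache_ratio scales the host
# pool relative to the device pool, and hicache_mem_layout selects the
# page_first vs. layer_first layout being benchmarked.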
kv_cache = token_to_kv_pool_allocator.get_kvcache()
token_to_kv_pool_host = MHATokenToKVPoolHost(
    kv_cache,
    hicache_ratio,
    hicache_size,
    page_size,
    hicache_mem_layout,
)

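# Cache controller providing the backup and transfer paths under test.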
load_cache_event = threading.Event()
cache_controller = HiCacheController(
    token_to_kv_pool_allocator,
    token_to_kv_pool_host,
    page_size,
    group,
    load_cache_event=load_cache_event,
    write_policy=hicache_write_policy,
    io_backend=hicache_io_backend,
    storage_backend=hicache_storage_backend,
    prefetch_threshold=prefetch_threshold,
)

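# Build backup operations: each covers op_size contiguous token indices with
# one hash value per page of page_size tokens.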
operations = [
    StorageOperation(
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        hash_value=[f"{j}" for j in range(i, i + op_size, page_size)],
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

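# Time the backup path matching the chosen host memory layout:
# zerocopy_page_backup for page_first, generic_page_backup for layer_first.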
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_backup(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_backup(operation, batch_size=128)
tok = time.monotonic()
print(f"backup: {tok - tik:.6f} s")

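# Build prefetch operations over the same token ranges.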
operations = [
    PrefetchOperation(
        f"{i}",
        torch.tensor(list(range(i, i + op_size))),
        list(range(i, i + op_size)),
        f"{i}",
    )
    for i in tqdm(range(0, op_num * op_size, op_size))
]

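# Derive page-aligned hash values for each prefetch operation from its
# last_hash (set to the range start above), one hash per page_size tokens.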
for operation in operations:
    operation.hash_value = [
        f"{j}"
        for j in range(
            int(operation.last_hash), int(operation.last_hash) + op_size, page_size
        )
    ]

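# Time the prefetch/transfer path matching the chosen host memory layout.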
tik = time.monotonic()
if hicache_mem_layout == "page_first":
    for operation in operations:
        cache_controller.zerocopy_page_transfer(operation, batch_size=128)
elif hicache_mem_layout == "layer_first":
    for operation in operations:
        cache_controller.generic_page_transfer(operation, batch_size=128)
tok = time.monotonic()
print(f"transfer: {tok - tik:.6f} s")