# bench_storage.py: functional test and throughput benchmarks for HiCacheHF3FS.
import json
import logging
import os
import random
import time
from typing import List

import torch
from tqdm import tqdm

from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
    Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS


def print_stats(x: List[float]):
    """Print mean, min, quartiles and max of the samples in ``x``.

    Every value is formatted with two decimals. The quartiles are taken by
    index from the sorted samples (``sorted_x[int(len * q)]``) without
    interpolation. The caller's list is not modified.

    Args:
        x: Numeric samples (for example bandwidth measurements). If it is
            empty, a short notice is printed instead of the statistics.
    """
    if not x:
        # Without this guard the mean would raise ZeroDivisionError.
        print("no samples")
        return

    samples = sorted(x)
    count = len(samples)
    print(
        f"mean = {sum(samples)/count:.2f}, "
        # The list is sorted, so the extremes are its first and last items.
        f"min = {samples[0]:.2f}, "
        f"p25 = {samples[int(count*0.25)]:.2f}, "
        f"p50 = {samples[int(count*0.5)]:.2f}, "
        f"p75 = {samples[int(count*0.75)]:.2f}, "
        f"max = {samples[-1]:.2f}"
    )


def test():
    """Run a functional smoke test of ``HiCacheHF3FS`` built from an env config.

    Writes a JSON config to the path named by the environment variable
    ``HiCacheHF3FS.default_env_var``, builds the store with
    ``HiCacheHF3FS.from_env_config`` and then asserts that:

    * ``set`` succeeds for the first ``file_size // bytes_per_page`` keys and
      is rejected for the remaining ones, which ``get`` reports as ``None``;
    * every stored key ``exists`` and ``get`` returns the tensor written;
    * after deleting one key a new key can be stored and read back;
    * after ``clear`` the metadata reports every page as free;
    * ``batch_set`` with more pages than fit returns a falsy value, while
      ``batch_get`` still returns the pages that did fit.

    The store is closed and its backing file removed even when a check fails.

    Raises:
        RuntimeError: If the config file cannot be written.
        AssertionError: If the environment variable is unset or a check fails.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    dtype = torch.bfloat16
    tokens_per_page = 64

    file_path_prefix = "/data/test"
    file_size = 128 << 20
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 2
    # Number of whole pages that fit into the file (8 with the sizes above).
    capacity = file_size // bytes_per_page

    config_path = os.getenv(HiCacheHF3FS.default_env_var)
    assert config_path, (
        f"environment variable {HiCacheHF3FS.default_env_var} must point to "
        "a writable config file path"
    )
    config = {
        "file_path_prefix": file_path_prefix,
        "file_size": file_size,
        "numjobs": numjobs,
        "entries": entries,
    }
    try:
        with open(config_path, "w") as f:
            json.dump(config, f)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(
            f"Failed to dump config to {config_path}: {str(e)}"
        ) from e

    hicache_hf3fs = HiCacheHF3FS.from_env_config(bytes_per_page, dtype)
    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        # Single-key API: insert more pages than fit and check the overflow
        # is rejected.
        num_pages = 10
        tensors = {}
        for i in range(num_pages):
            k = f"key_{i}"
            v = torch.randn((numel,)).to(dtype=dtype)
            ok = hicache_hf3fs.set(k, v)
            if i < capacity:
                assert ok, f"Failed to insert {k}"
            else:
                assert not ok, f"Insert of {k} should have been rejected"
            tensors[k] = v
        for i in range(capacity, num_pages):
            assert hicache_hf3fs.get(f"key_{i}") is None

        for i in range(hicache_hf3fs.num_pages):
            k = f"key_{i}"
            assert hicache_hf3fs.exists(k)
            out = hicache_hf3fs.get(k)
            assert out is not None
            v = tensors[k]
            assert torch.allclose(v, out, atol=1e-3), f"Tensor mismatch for {k}"

        assert not hicache_hf3fs.exists("not_exists")

        # Inserts beyond capacity were rejected above, so delete the last
        # stored key before expecting a new key to be accepted.
        hicache_hf3fs.delete(f"key_{capacity - 1}")
        v2 = torch.randn((numel,)).to(dtype=dtype)
        assert hicache_hf3fs.set("key_new", v2)
        assert torch.allclose(hicache_hf3fs.get("key_new"), v2, atol=1e-3)

        hicache_hf3fs.clear()
        assert (
            len(hicache_hf3fs.metadata_client.rank_metadata.free_pages)
            == hicache_hf3fs.metadata_client.rank_metadata.num_pages
        )

        # batch
        num_pages = 10
        keys = []
        values = []
        for i in range(num_pages):
            keys.append(f"key_{i}")
            values.append(torch.randn((numel,)).to(dtype=dtype))

        ok = hicache_hf3fs.batch_set(keys, values)
        assert not ok
        for i in range(capacity, num_pages):
            assert hicache_hf3fs.get(f"key_{i}") is None

        stored = hicache_hf3fs.num_pages
        results = hicache_hf3fs.batch_get(keys[:stored])
        for result, key, value in zip(results, keys[:stored], values[:stored]):
            assert torch.allclose(
                value, result, atol=1e-3
            ), f"Tensor mismatch for {key}"
    finally:
        # Release the store and delete its backing file even if a check
        # failed, so a failed run does not leave the test file behind.
        hicache_hf3fs.close()
        if os.path.exists(hicache_hf3fs.file_path):
            os.remove(hicache_hf3fs.file_path)

    print("All test cases passed.")


def bench():
    """Measure batch write and batch read throughput of ``HiCacheHF3FS``.

    Builds a store on ``/data/test.bin`` with a local metadata client, then
    runs ``warmup + iteration`` rounds of ``batch_set`` followed by the same
    number of rounds of ``batch_get`` on randomly sampled stored keys. Each
    round moves ``num_page`` pages. Bandwidth is computed per round as
    GiB moved divided by elapsed ``time.perf_counter`` seconds. Only the
    rounds after the warm-up are passed to ``print_stats``.

    The store is closed even when a round fails.

    Raises:
        AssertionError: If a write is reported as failed or a read returns
            ``None`` for any key, including during warm-up.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    hicache_hf3fs = HiCacheHF3FS(
        rank=0,
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
        metadata_client=Hf3fsLocalMetadataClient(),
    )
    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        values = [
            torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))
        ]

        warmup = 50
        iteration = 100
        # Data moved per round, in GiB; identical for writes and reads.
        batch_size = num_page * bytes_per_page / (1 << 30)

        w_bw = []
        for i in tqdm(range(warmup + iteration), desc="Benchmarking write (GB/s)"):
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            tik = time.perf_counter()
            ok = hicache_hf3fs.batch_set(keys, values)
            tok = time.perf_counter()
            # Check every round, so a failed warm-up write is not silently
            # ignored and later measured against.
            assert ok, f"batch_set failed in round {i}"
            if i >= warmup:
                w_bw.append(batch_size / (tok - tik))
        print_stats(w_bw)

        # Nothing is written during the read phase, so the candidate keys are
        # collected once instead of on every round.
        stored_keys = list(
            hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()
        )
        r_bw = []
        for i in tqdm(range(warmup + iteration), desc="Benchmarking read (GB/s)"):
            keys = random.sample(stored_keys, num_page)
            tik = time.perf_counter()
            results = hicache_hf3fs.batch_get(keys)
            tok = time.perf_counter()
            assert all(
                r is not None for r in results
            ), f"batch_get returned a missing page in round {i}"
            if i >= warmup:
                r_bw.append(batch_size / (tok - tik))
        print_stats(r_bw)
    finally:
        hicache_hf3fs.close()


def allclose():
    """Check that pages read back with ``batch_get`` match what was written.

    Builds a store on ``/data/test.bin`` with a local metadata client, writes
    ``iteration`` batches of ``num_page`` pages with ``batch_set`` (key ``j``
    receives ``values[j % num_page]``), then reads ``iteration`` random
    batches and compares each returned tensor with the expected one using
    ``torch.allclose(..., atol=1e-3)``.

    Each batch is verified right after it is read. Collecting all results
    first would keep ``iteration * num_page`` returned pages referenced at
    the same time.

    The store is closed even when a check fails.

    Raises:
        AssertionError: If a write is reported as failed, a read returns
            ``None``, or a tensor differs from the one written.
    """
    # Qwen3-32B
    layer_num = 64
    head_num, head_dim = 8, 128
    dtype = torch.bfloat16
    tokens_per_page = 64

    file_path = "/data/test.bin"
    file_size = 1 << 40
    numjobs = 16
    bytes_per_page = 16 << 20
    entries = 8
    hicache_hf3fs = HiCacheHF3FS(
        rank=0,
        file_path=file_path,
        file_size=file_size,
        numjobs=numjobs,
        bytes_per_page=bytes_per_page,
        entries=entries,
        dtype=dtype,
        metadata_client=Hf3fsLocalMetadataClient(),
    )
    try:
        numel = 2 * tokens_per_page * layer_num * head_num * head_dim
        assert numel * dtype.itemsize == bytes_per_page

        num_page = 128
        values = [
            torch.randn((numel,)).to(dtype=dtype) for _ in tqdm(range(num_page))
        ]

        iteration = 100

        # The labels describe what this function does: it verifies data and
        # measures no bandwidth.
        for i in tqdm(range(iteration), desc="Writing pages"):
            keys = [f"{j}" for j in range(i * num_page, (i + 1) * num_page)]
            ok = hicache_hf3fs.batch_set(keys, values)
            assert ok, f"batch_set failed in round {i}"

        # Nothing is written during the read phase, so the candidate keys are
        # collected once.
        stored_keys = list(
            hicache_hf3fs.metadata_client.rank_metadata.key_to_index.keys()
        )
        for _ in tqdm(range(iteration), desc="Reading and verifying pages"):
            keys = random.sample(stored_keys, num_page)
            results = hicache_hf3fs.batch_get(keys)
            assert all(r is not None for r in results)
            for key, result in zip(keys, results):
                # Key j was written with values[j % num_page] in the write loop.
                assert torch.allclose(
                    values[int(key) % num_page], result, atol=1e-3
                ), f"Tensor mismatch for key {key}"
    finally:
        hicache_hf3fs.close()


def main():
    """Enable INFO-level logging, then run test(), bench() and allclose() in order."""
    logging.basicConfig(level=logging.INFO)
    for step in (test, bench, allclose):
        step()


# Call main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()