import random
from typing import List

import numpy as np
import pytest
import torch

import colossalai
from colossalai.legacy.nn.parallel.layers import (
    CachedEmbeddingBag,
    CachedParamMgr,
    EvictionStrategy,
    ParallelCachedEmbeddingBag,
    ParallelCachedEmbeddingBagTablewise,
    TablewiseEmbeddingBagConfig,
)
from colossalai.legacy.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.tensor import ColoTensor
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn

# Table size and embedding width used by the single-process CachedEmbeddingBag tests.
NUM_EMBED, EMBED_DIM = 10, 8
# Passed to synthesize_1d_sparse_feature, which produces 2 * BATCH_SIZE bags per batch.
BATCH_SIZE = 8


def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed`` for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)


def synthesize_1d_sparse_feature(
    batch_size,
    num_embed,
    device,
):
    """Draw a random (indices, offsets) pair for a single sparse feature.

    ``2 * batch_size`` ids are drawn uniformly from ``[0, num_embed)`` with
    torch's RNG. The offsets start at 0, end at the total number of ids, and
    contain ``2 * batch_size - 1`` sorted random cut points (drawn with NumPy's
    RNG) in between. With ``include_last_offset=True`` they therefore describe
    ``2 * batch_size`` bags, some of which may be empty.

    Returns:
        ``(indices, offsets)``, both ``torch.long`` tensors on ``device``.
    """
    total_ids = 2 * batch_size
    ids = torch.randint(low=0, high=num_embed, size=(total_ids,), device=device, dtype=torch.long)
    cut_points = np.sort(np.random.randint(low=0, high=total_ids, size=(total_ids - 1,)))
    boundaries = np.concatenate(([0], cut_points, [total_ids]))
    return ids, torch.from_numpy(boundaries).to(device).long()


@pytest.mark.skip
@clear_cache_before_run()
def test_cachemgr():
    """Exercise the admit/evict/flush bookkeeping of ``CachedParamMgr`` directly.

    Skipped by default. NOTE(review): this calls private methods (``_admit``,
    ``_evict``, ``_chunk_in_cuda``, ``_prepare_rows_on_cuda``); confirm they
    still exist on ``CachedParamMgr`` before un-skipping.
    """
    model = torch.nn.EmbeddingBag(10000, 128)
    # 10000-row weight; the second argument becomes ``cuda_row_num`` (asserted below)
    mgr = CachedParamMgr(model.weight.detach(), 5)
    assert mgr.cuda_row_num == 5

    mgr._admit(1)
    assert not mgr._chunk_in_cuda(2)
    assert mgr._chunk_in_cuda(1)

    mgr._admit(8)

    # after admitting two entries, 3 slots are still available
    assert mgr.cuda_available_chunk_num == 3

    mgr._evict()
    assert mgr.cuda_available_chunk_num == 4

    mgr._prepare_rows_on_cuda(torch.tensor([9, 6, 5], dtype=torch.long, device=0))
    mgr._prepare_rows_on_cuda(torch.tensor([3, 4, 5], dtype=torch.long, device=0))

    mgr.flush()
    assert mgr.cuda_available_chunk_num == 5


@clear_cache_before_run()
def test_reorder_with_freq():
    """Check that ``CachedParamMgr.reorder`` ranks ids by descending ``idx_map`` value.

    ``idx_map`` holds a random per-id score (used here as a frequency). After
    ``reorder``, the manager's ``idx_map`` is expected to send each id to its
    position in the descending-score order. That position is then split into
    a chunk id and an offset inside the chunk.
    """
    num_embed = 100
    chunk_size = 1
    num_chunk = 5

    idx_map = torch.randint(10000, size=(num_embed,))
    sorted_idx = torch.argsort(idx_map, descending=True)
    # rank[i] is the position of id i in sorted_idx, i.e. the inverse permutation.
    # This replaces an O(n^2) list.index() search with a single scatter.
    rank = torch.empty_like(sorted_idx)
    rank[sorted_idx] = torch.arange(num_embed)

    dev = torch.device('cuda')
    chunkid = torch.div(rank, chunk_size, rounding_mode='floor').to(dev)
    offset_in_chunk = torch.remainder(rank, chunk_size).to(dev)

    weight = torch.rand(num_embed, 2)
    mgr = CachedParamMgr(weight, num_chunk)

    mgr.reorder(idx_map)

    indices = mgr.idx_map.index_select(0, torch.arange(num_embed, dtype=torch.long, device=dev))
    mgr_chunk_id = torch.div(indices, chunk_size, rounding_mode='floor')
    mgr_offsets = torch.remainder(indices, chunk_size)
    assert torch.allclose(chunkid, mgr_chunk_id), f"chunk id: {chunkid}, mgr: {mgr_chunk_id}"
    assert torch.allclose(offset_in_chunk, mgr_offsets), \
        f"offset in chunk: {offset_in_chunk}, mgr: {mgr_offsets}"


@clear_cache_before_run()
@parameterize('use_LFU', [True, False])
def test_freq_aware_embed(use_LFU: bool):
    """Compare ``CachedEmbeddingBag`` with a plain ``torch.nn.EmbeddingBag``.

    Both models start from the same weights and take five identical SGD steps
    on random sparse batches. Forward outputs are compared at every step, and
    the cached model's weights are compared with the reference after
    ``cache_weight_mgr.flush()``.

    Args:
        use_LFU: use ``EvictionStrategy.LFU`` if True, else ``EvictionStrategy.DATASET``.
    """
    device = torch.device('cuda', 0)
    evict_strategy = EvictionStrategy.LFU if use_LFU else EvictionStrategy.DATASET
    model = CachedEmbeddingBag(NUM_EMBED,
                               EMBED_DIM,
                               mode='mean',
                               include_last_offset=True,
                               cache_ratio=min(BATCH_SIZE * 2 / NUM_EMBED, 1.0),
                               ids_freq_mapping=None,
                               evict_strategy=evict_strategy).to(device)

    assert model.weight.shape[0] == NUM_EMBED
    ref_model = torch.nn.EmbeddingBag.from_pretrained(model.weight.detach().to(device),
                                                      mode='mean',
                                                      include_last_offset=True,
                                                      freeze=False)

    assert torch.allclose(ref_model.weight.detach(), model.weight.detach().to(device))

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(BATCH_SIZE, NUM_EMBED, device)
        res = model(indices, offsets)
        ref_res = ref_model(indices, offsets)
        assert torch.allclose(res, ref_res), f"model result: {res}, reference: {ref_res}"

        grad = torch.rand_like(res)
        # comparing gradients directly is nontrivial; weights are compared after the loop instead
        res.backward(grad)
        ref_res.backward(grad)
        optimizer.step()
        optimizer.zero_grad()

        ref_optimizer.step()
        ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    model_weight = model.weight.detach().to(device)
    ref_weight = ref_model.weight.detach()
    # The table only has NUM_EMBED rows, so report the size and location of the
    # mismatch instead of a fixed row slice (which could be empty).
    assert torch.allclose(model_weight, ref_weight), \
        f"max abs diff: {(model_weight - ref_weight).abs().max().item()}, " \
        f"mismatched rows: {(~torch.isclose(model_weight, ref_weight)).any(dim=1).nonzero().flatten().tolist()}"

@clear_cache_before_run()
@parameterize('init_freq', [True, False])
def test_lfu_strategy(init_freq: bool):
    """Minimal behavioural check of ``EvictionStrategy.LFU`` in ``CachedEmbeddingBag``.

    A 5-row table with ``cache_ratio=3 / 5`` (presumably 3 cached rows) is
    first fed a warm-up access pattern in which ids 0 and 2 are requested far
    more often than id 1. It is then probed with a short sequence whose hit
    counts are compared with what an LFU policy predicts.

    Args:
        init_freq: if True, pass an initial ``ids_freq_mapping`` to the bag.
    """
    bag = CachedEmbeddingBag(5,
                             5,
                             cache_ratio=3 / 5,
                             buffer_size=0,
                             pin_weight=True,
                             ids_freq_mapping=[4, 2, 1, 3, 1] if init_freq else None,
                             warmup_ratio=1.0,
                             evict_strategy=EvictionStrategy.LFU)

    # a single bag per call: every lookup below is one sample
    offsets = torch.tensor([0], device="cuda:0")

    def lookup(ids):
        bag.forward(torch.tensor(ids, device="cuda:0"), offsets)

    # prepare frequency learning info: same 15 calls, in the same order, as before
    warmup = [[2], [1, 2], [0, 2]] + [[0, 1, 2]] * 4 + [[0, 2]] * 4 + [[0]] * 4
    for ids in warmup:
        lookup(ids)

    # check strategy
    lookup([0, 1, 2])
    lookup([0, 1, 2])
    lookup([3])    # miss, evict 1
    lookup([2])    # hit
    lookup([4])    # miss, evict 3
    lookup([2])    # hit
    lookup([0])    # hit

    # presumably one history entry per forward call: the last six cover the probes after the first [0, 1, 2]
    assert torch.allclose(torch.Tensor(bag.cache_weight_mgr.num_hits_history[-6:]), torch.Tensor([3, 0, 1, 0, 1, 1])), \
        "LFU strategy behavior failed"


def gather_tensor(tensor, rank, world_size):
    """Collect ``tensor`` from every rank onto rank 0 with ``torch.distributed.gather``.

    Must be called by all ranks. Rank 0 receives a list of ``world_size``
    tensors, one per rank; every other rank gets an empty list.
    """
    buckets = [torch.empty_like(tensor) for _ in range(world_size)] if rank == 0 else []
    torch.distributed.gather(tensor, buckets, dst=0)
    return buckets


def run_parallel_freq_aware_embed_tablewise(rank, world_size):
    """Check ``ParallelCachedEmbeddingBagTablewise`` on exactly two ranks.

    Three feature tables are assigned to ranks: tables 1 and 2 to rank 0, and
    table 3 to rank 1. After one forward/backward/SGD step, rank 0 flushes its
    cache. It then compares its weights with a single ``torch.nn.EmbeddingBag``
    that holds all three tables and receives the equivalent gradient.
    Returns immediately for any ``world_size`` other than 2.
    """
    if world_size != 2:
        return
    device = torch.device('cuda', torch.cuda.current_device())

    # Each rank draws the weights and the fake gradient itself, and the rank-0
    # reference check assumes those draws are identical on every rank. Seed
    # explicitly (as the columnwise test does) and draw every random tensor
    # before building the model, so the values cannot depend on how much RNG
    # state model construction consumes on each rank.
    set_seed(4321)

    # 3 feature tables. idx: 0~5, 6~10, 11~17
    weight_tables = torch.rand(18, 5)
    weight_table1 = weight_tables[0:6]
    weight_table2 = weight_tables[6:11]
    weight_table3 = weight_tables[11:18]
    # one row per sample; columns are the three tables' 5-dim outputs side by side
    rand_grad = torch.rand(3, 5 * 3, device=device)

    embedding_bag_config_list: List[TablewiseEmbeddingBagConfig] = [
        TablewiseEmbeddingBagConfig(num_embeddings=6,
                                    cuda_row_num=4,
                                    assigned_rank=0,
                                    initial_weight=weight_table1.clone().detach().cpu()),
        TablewiseEmbeddingBagConfig(num_embeddings=5,
                                    cuda_row_num=4,
                                    assigned_rank=0,
                                    initial_weight=weight_table2.clone().detach().cpu()),
        TablewiseEmbeddingBagConfig(num_embeddings=7,
                                    cuda_row_num=4,
                                    assigned_rank=1,
                                    initial_weight=weight_table3.clone().detach().cpu()),
    ]
    # each rank passes the concatenated weights of the tables assigned to it
    _weight = torch.cat([weight_table1, weight_table2], 0) if rank == 0 else weight_table3
    model = ParallelCachedEmbeddingBagTablewise(
        embedding_bag_config_list,
        embedding_dim=5,
        _weight=_weight,
        include_last_offset=True,
        cache_ratio=0.5,
        buffer_size=0,
        evict_strategy=EvictionStrategy.LFU,
    )

    # Input, in KJT (feature-major) format:
    #   batch    feature 1   feature 2   feature 3
    #   input0   [1,2,3]     [6,7]       []
    #   input1   []          [9]         [13,15]
    #   input2   [1,5]       [6,8]       [11]
    #            rank 0      rank 0      rank 1
    indices = torch.tensor([1, 2, 3, 1, 5, 6, 7, 9, 6, 8, 13, 15, 11], device=device)
    offsets = torch.tensor([0, 3, 3, 5, 7, 8, 10, 10, 12, 13], device=device)
    res = model(indices, offsets, already_split_along_rank=False)

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    rand_grad = rand_grad.to(dtype=res.dtype, device=res.device)
    # presumably the output is split along the batch: rank 0 holds samples 0-1, rank 1 sample 2
    fake_grad = rand_grad[0:2] if rank == 0 else rand_grad[2:]
    res.backward(fake_grad)
    optimizer.step()
    optimizer.zero_grad()

    # check correctness on rank 0, which owns tables 1 and 2 (rows 0~10)
    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight_tables.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-2)
        # regroup the (batch, table * dim) gradient into one row per (table, sample) bag
        ref_fake_grad = torch.cat(rand_grad.split(5, 1), 0)
        ref_res = ref_model(indices, offsets)
        ref_res.backward(ref_fake_grad)
        ref_optimizer.step()
        ref_optimizer.zero_grad()

        model.cache_weight_mgr.flush()
        recover_weight = model.cache_weight_mgr.weight.to(device)
        ref_weight = ref_model.weight.detach()[:11]
        assert torch.allclose(recover_weight, ref_weight), f"{recover_weight - ref_weight}"

def run_parallel_freq_aware_embed_columnwise(rank, world_size):
    """Check ``ParallelCachedEmbeddingBag`` with its weight sharded along the embedding dim.

    Each rank holds a column slice (``ShardSpec(dims=[-1])``) of a 100 x 16
    table. Five random batches are run. Rank 0 gathers the per-rank outputs
    and compares them with a plain ``torch.nn.EmbeddingBag``. At the end it
    gathers the flushed weight shards and compares them with the reference
    weights after the same SGD steps.
    """
    device = torch.device('cuda', torch.cuda.current_device())

    num_embed = 100
    embed_dim = 16
    batch_size = 4

    set_seed(4321)
    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoTensor(weight.clone().detach().cpu(), spec=None)

    # initialize the tensor spec for the embedding weight parameter,
    # which is a ColoParameter.
    coloweight.set_process_group(ProcessGroup(tp_degree=world_size))
    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))

    model = ParallelCachedEmbeddingBag.from_pretrained(
        coloweight,
        include_last_offset=True,
        freeze=False,
        cache_ratio=batch_size * 2 / num_embed,
    )

    assert model.cache_weight_mgr.weight.device.type == 'cpu'
    assert model.cache_weight_mgr.cuda_cached_weight.requires_grad
    # this rank's column shard of the full table
    weight_in_rank = torch.tensor_split(weight, world_size, -1)[rank]
    print(f"model weight: {model.cache_weight_mgr.weight.shape}, ref weight: {weight_in_rank.shape}")
    assert torch.allclose(weight_in_rank,
                          model.cache_weight_mgr.weight.detach()), f"{weight_in_rank - model.cache_weight_mgr.weight}"

    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    if rank == 0:
        ref_model = torch.nn.EmbeddingBag.from_pretrained(weight.detach().clone(),
                                                          include_last_offset=True,
                                                          freeze=False).to(device)
        ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=1e-3)

    # re-seed so every rank synthesizes identical batches and gradients
    set_seed(4321)
    for _ in range(5):
        indices, offsets = synthesize_1d_sparse_feature(batch_size, num_embed, device)
        res = model(indices, offsets)

        # full-batch gradient; each rank back-propagates only its own slice of the batch
        grad = torch.rand(batch_size * 2, embed_dim, dtype=res.dtype, device=res.device)
        grad_in_rank = torch.tensor_split(grad, world_size, 0)[rank]
        res.backward(grad_in_rank)

        optimizer.step()
        optimizer.zero_grad()

        res_list = gather_tensor(res.detach(), rank, world_size)

        if rank == 0:
            ref_res = ref_model(indices, offsets)
            recover_res = torch.cat(res_list, dim=0)

            assert torch.allclose(ref_res, recover_res)

            ref_res.backward(grad)
            ref_optimizer.step()
            ref_optimizer.zero_grad()

    model.cache_weight_mgr.flush()
    weight_list = gather_tensor(model.cache_weight_mgr.weight.detach().cuda(), rank, world_size)
    if rank == 0:
        # re-assemble the column shards into the full table
        recover_weight = torch.cat(weight_list, dim=1)
        assert torch.allclose(recover_weight, ref_model.weight.detach()), f"{recover_weight - ref_model.weight}"


def run_dist(rank, world_size, port):
    """Per-process entry point: launch the legacy ColossalAI runtime over NCCL, then run the distributed checks."""
    colossalai.legacy.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # NOTE(review): the columnwise check is disabled here; the reason is not recorded in this file.
    # run_parallel_freq_aware_embed_columnwise(rank, world_size)
    run_parallel_freq_aware_embed_tablewise(rank, world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@rerun_if_address_is_in_use()
def test_parallel_freq_aware_embed(world_size):
    """Spawn ``world_size`` processes that each run ``run_dist``.

    ``run_parallel_freq_aware_embed_tablewise`` returns immediately unless
    ``world_size == 2``. Size 2 must therefore be parametrized, or this test
    passes without checking anything.
    """
    spawn(run_dist, world_size)


if __name__ == '__main__':
    # Uncomment to run the single-process checks directly.
    # test_freq_aware_embed(True)
    test_parallel_freq_aware_embed(2)
    # test_lfu_strategy(False)