test_csr.py 1.19 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import random
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor


def test_csr_addition_self():
    """Check that adding a CSRTensor to itself doubles its dense values.

    Builds a small random 0/1 matrix and checks that wrapping it in a
    CSRTensor round-trips through ``to_dense()``. It then calls
    ``cx.add(cx)`` and checks that ``cx`` itself (updated in place)
    densifies to ``dense_x + dense_x``.
    """
    row_count = 10
    # Local RNG: yields the same sequence as ``random.seed(1234)`` without
    # clobbering the global ``random`` state that other tests share.
    rng = random.Random(1234)

    # The first row is always ones (presumably so the CSR form has at least
    # one stored row). Each remaining row is all ones with probability ~0.25
    # and all zeros otherwise.
    rows = [torch.ones(1, 5)] + [
        torch.ones(1, 5) if rng.random() > 0.75 else torch.zeros(1, 5)
        for _ in range(row_count - 1)
    ]
    # A single concatenation instead of re-copying a growing tensor per row.
    x = torch.cat(rows)
    # Snapshot before handing x to CSRTensor, in case it aliases or mutates x.
    dense_x = x.clone()
    cx = CSRTensor(x)

    # torch.equal also compares shapes, so a wrongly-shaped to_dense() result
    # cannot pass via broadcasting the way ``torch.all(a == b)`` can.
    assert torch.equal(dense_x, cx.to_dense())

    cx.add(cx)
    assert torch.equal(dense_x + dense_x, cx.to_dense())


def test_csr_addition_different():
    """Check that adding one CSRTensor to another matches the dense sum.

    Builds two random 0/1 matrices of the same shape and wraps each in a
    CSRTensor. It then calls ``cx.add(cy)`` and checks that ``cx`` itself
    (updated in place) densifies to ``dense_x + dense_y``.
    """
    row_count = 10
    # Local RNG: yields the same sequence as ``random.seed(1234)`` without
    # clobbering the global ``random`` state that other tests share.
    rng = random.Random(1234)

    def random_rows():
        # The first row is always ones. Each of the other row_count - 1 rows
        # is all ones with probability ~0.25, else all zeros. This consumes
        # exactly row_count - 1 draws from rng.
        rows = [torch.ones(1, 5)] + [
            torch.ones(1, 5) if rng.random() > 0.75 else torch.zeros(1, 5)
            for _ in range(row_count - 1)
        ]
        # A single concatenation instead of re-copying a growing tensor.
        return torch.cat(rows)

    # x and y take consecutive draws from the same stream, so they
    # generally differ. The clones snapshot the inputs before CSRTensor
    # sees them, in case it aliases or mutates them.
    x = random_rows()
    dense_x = x.clone()
    cx = CSRTensor(x)

    y = random_rows()
    dense_y = y.clone()
    cy = CSRTensor(y)

    dense_sum = dense_x + dense_y
    cx.add(cy)

    # torch.equal also compares shapes, so a wrongly-shaped to_dense() result
    # cannot pass via broadcasting the way ``torch.all(a == b)`` can.
    assert torch.equal(dense_sum, cx.to_dense())