Unverified Commit ccec2463 authored by Jeff Rasley's avatar Jeff Rasley Committed by GitHub
Browse files

add some csr addition unit tests (#110)

parent 5aa58b38
import torch
import random
from deepspeed.pt.deepspeed_csr_tensor import CSRTensor
def test_csr_addition_self():
row_count = 10
random.seed(1234)
x = torch.ones(1, 5)
for i in range(row_count - 1):
if random.random() > 0.75:
x = torch.cat([x, torch.ones(1, 5)])
else:
x = torch.cat([x, torch.zeros(1, 5)])
dense_x = x.clone()
cx = CSRTensor(x)
assert torch.all(dense_x == cx.to_dense())
cx.add(cx)
assert torch.all(dense_x + dense_x == cx.to_dense())
def test_csr_addition_different():
row_count = 10
random.seed(1234)
x = torch.ones(1, 5)
for i in range(row_count - 1):
if random.random() > 0.75:
x = torch.cat([x, torch.ones(1, 5)])
else:
x = torch.cat([x, torch.zeros(1, 5)])
dense_x = x.clone()
cx = CSRTensor(x)
y = torch.ones(1, 5)
for i in range(row_count - 1):
if random.random() > 0.75:
y = torch.cat([y, torch.ones(1, 5)])
else:
y = torch.cat([y, torch.zeros(1, 5)])
dense_y = y.clone()
cy = CSRTensor(y)
dense_sum = dense_x + dense_y
cx.add(cy)
assert torch.all(dense_sum == cx.to_dense())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment