Unverified Commit bab57f23 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[CI] Speed up sparse tensor core test via vectorized generating sparse data (#1009)

parent 340bfc50
......@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.")
full_tensor = torch.randn(shape, dtype=dtype, device=device)
mask = torch.zeros_like(full_tensor, dtype=torch.bool)
group_count = shape[-1] // 4
group_shape = shape[:-1] + (group_count, 4)
reshaped = full_tensor.view(*group_shape)
for idx in range(reshaped.numel() // 4):
flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64)
while flat_idx[0] == flat_idx[1]:
flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64)
i = idx // group_count
j = idx % group_count
mask.view(*group_shape)[i, j, flat_idx[0]] = True
mask.view(*group_shape)[i, j, flat_idx[1]] = True
rand_vals = torch.rand(group_shape, device=device)
topk_indices = rand_vals.topk(k=2, dim=-1).indices
mask = torch.zeros(group_shape, dtype=torch.bool, device=device)
mask.scatter_(-1, topk_indices, True)
mask = mask.view(shape)
sparse_tensor = full_tensor * mask
return sparse_tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment