Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bab57f23
Unverified
Commit
bab57f23
authored
Oct 13, 2025
by
Lei Wang
Committed by
GitHub
Oct 13, 2025
Browse files
[CI] Speed up sparse tensor core test via vectorized generating sparse data (#1009)
parent
340bfc50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
12 deletions
+5
-12
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
...s/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
+5
-12
No files found.
examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
View file @
bab57f23
...
@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
...
@@ -66,21 +66,14 @@ def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'):
raise
ValueError
(
"Last dimension must be divisible by 4 for 2:4 sparsity."
)
raise
ValueError
(
"Last dimension must be divisible by 4 for 2:4 sparsity."
)
full_tensor
=
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
device
)
full_tensor
=
torch
.
randn
(
shape
,
dtype
=
dtype
,
device
=
device
)
mask
=
torch
.
zeros_like
(
full_tensor
,
dtype
=
torch
.
bool
)
group_count
=
shape
[
-
1
]
//
4
group_count
=
shape
[
-
1
]
//
4
group_shape
=
shape
[:
-
1
]
+
(
group_count
,
4
)
group_shape
=
shape
[:
-
1
]
+
(
group_count
,
4
)
reshaped
=
full_tensor
.
view
(
*
group_shape
)
rand_vals
=
torch
.
rand
(
group_shape
,
device
=
device
)
topk_indices
=
rand_vals
.
topk
(
k
=
2
,
dim
=-
1
).
indices
for
idx
in
range
(
reshaped
.
numel
()
//
4
):
mask
=
torch
.
zeros
(
group_shape
,
dtype
=
torch
.
bool
,
device
=
device
)
flat_idx
=
torch
.
randint
(
0
,
4
,
(
2
,),
dtype
=
torch
.
int64
)
mask
.
scatter_
(
-
1
,
topk_indices
,
True
)
while
flat_idx
[
0
]
==
flat_idx
[
1
]:
mask
=
mask
.
view
(
shape
)
flat_idx
[
1
]
=
torch
.
randint
(
0
,
4
,
(
1
,),
dtype
=
torch
.
int64
)
i
=
idx
//
group_count
j
=
idx
%
group_count
mask
.
view
(
*
group_shape
)[
i
,
j
,
flat_idx
[
0
]]
=
True
mask
.
view
(
*
group_shape
)[
i
,
j
,
flat_idx
[
1
]]
=
True
sparse_tensor
=
full_tensor
*
mask
sparse_tensor
=
full_tensor
*
mask
return
sparse_tensor
return
sparse_tensor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment