OpenDAS / tilelang

Commit 20810691
Authored by Tong WU on Jun 04, 2025; committed by LeiWang1999 on Jun 04, 2025

[CI] Add hadamard example to CI (#549)

* [CI] Add hadamard example to CI
* Run yapf and ruff
* Run yapf and ruff
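Editor's note: the example added by this commit validates a TileLang Hadamard-transform kernel against a SciPy reference (scipy.linalg.hadamard) and benchmarks its latency. Assuming tilelang, torch with a CUDA device, and scipy are installed, it can be run directly; the flags below are the ones defined in the file, and the values shown are its defaults:

    python examples/hadamard_transform/example_hadamard.py --batch 64 --dim 32768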
parent 6cbec330
Showing 1 changed file with 157 additions and 0 deletions.
examples/hadamard_transform/example_hadamard.py (new file, 0 → 100644, +157 -0)
import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout
import math
import argparse
import torch
from torch.nn import functional as F
import scipy


def is_pow_of_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


def hadamard(b, n, dtype):
    assert is_pow_of_2(n), "n must be a power of 2"
    assert 2 <= n <= 32768, "n must be in [2, 32768]"
    elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype]
    logN = int(math.log2(n))
    threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
    thread_elem = n // threads  # Each thread is responsible for a chunk of elements
    thread_round = int(math.log2(thread_elem))
    warps = 1 if threads <= 32 else threads // 32
    warp_round = int(math.log2(threads / warps))
    warp_size = threads // warps
    block_round = int(math.log2(warps))
    # Suppose we use 32KB shared memory at most
    exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1
    thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem

    # debug log
    # print(f'{threads=}, {thread_round=}')
    # print(f'{warps=}, {warp_round=}, {warp_size=}')
    # print(f'{block_round=}')
    # print(f'{exchange_round=}')

    @T.macro
    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype),
                  round: int):
        tx = T.get_thread_binding(0)
        for i in T.serial(round):
            tx_stride = 1 << i
            another_tx = tx ^ tx_stride
            # get i-th lowest bit of tx, which determines the operation type for shared[tx, :]
            sign = (tx >> i) & 1
            for j in T.Pipelined(thread_elem, num_stages=1):
                buf[j] = T.tvm_warp_shuffle(
                    0xffffffff,  # mask of all threads
                    local[j],
                    another_tx % warp_size,
                    warp_size,
                    warp_size)
                local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])

    @T.prim_func
    def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)):
        with T.Kernel(b, threads=threads) as bx:
            local = T.alloc_local((thread_elem,), dtype)
            shared = T.alloc_shared((threads, thread_elem_in_smem), dtype)
            T.annotate_layout({shared: make_mma_swizzle_layout(shared)})
            tx = T.get_thread_binding(0)

            # 1. Load from HBM to register
            for i in T.vectorized(thread_elem):
                local[i] = A[bx, tx * thread_elem + i]

            # 2. Hadamard inside thread, n<=8
            for i in T.serial(thread_round):
                chunksize = 1 << (i + 1)
                chunknum = thread_elem // chunksize
                for j in T.serial(chunknum):
                    chunkbase = j * chunksize
                    for k in T.serial(chunksize // 2):
                        local[chunkbase + k] = local[chunkbase + k] + local[chunkbase + k +
                                                                            chunksize // 2]
                        local[chunkbase + k + chunksize // 2] = local[
                            chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]

            # 3. Hadamard inside warp, n<=512
            # In warp level, we rely on warp shuffle to exchange data inside each warp,
            # without using shared memory
            another_val = T.alloc_local((thread_elem,), dtype)
            warp_shfl(local, another_val, warp_round)

            # 4. Hadamard inside block, n<=32768
            # Only exchange once for n<=8192, since shared mem can hold all elems
            if block_round > 0:
                warp_id = tx // warp_size
                lane_id = tx % warp_size
                src_tx = warp_id * warp_size + lane_id
                tgt_warp_id = tx % warps
                tgt_lane_id = tx // warps
                tgt_tx = tgt_warp_id * warp_size + tgt_lane_id
                # 4.1 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[src_tx, j] = local[exchange_base + j]
                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[tgt_tx, j]
                # 4.2 Warp shuffle
                warp_shfl(local, another_val, block_round)
                # 4.3 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[tgt_tx, j] = local[exchange_base + j]
                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[src_tx, j]

            # 5. Write back to HBM
            for i in T.vectorized(thread_elem):
                B[bx, tx * thread_elem + i] = local[i]

    return main


def ref_program(x: torch.Tensor):
    assert x.ndim == 2
    dim = x.shape[-1]
    assert is_pow_of_2(dim)
    return F.linear(
        x,
        torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=64, help='Batch size')
    parser.add_argument('--dim', type=int, default=32768, help='Dimension')
    args = parser.parse_args()
    B, D = args.batch, args.dim

    x = torch.randn((B, D), device='cuda')
    kernel = tilelang.compile(hadamard(B, D, 'float32'), out_idx=1)
    y = kernel(x)
    y_ref = ref_program(x)
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
    print('All tests passed.')

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    latency = profiler.do_bench(warmup=100)
    print("Tile-lang: {:.2f} ms".format(latency))


if __name__ == '__main__':
    main()
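Editor's note, for orientation only (this sketch is not part of the commit): steps 2-4 of the kernel implement the standard fast Walsh-Hadamard butterfly (a, b) -> (a + b, a - b), first within a thread's registers, then across lanes via warp shuffles, then across warps via shared memory. Below is a minimal pure-PyTorch sketch of that same unnormalized transform; the helper name fwht_reference is made up here, and the claim is only that it should match ref_program up to floating-point error, since the Sylvester-ordered Hadamard matrix is symmetric.

import torch


def fwht_reference(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference sketch: unnormalized fast Walsh-Hadamard transform
    # along the last dimension, using the same (a, b) -> (a + b, a - b)
    # butterfly the TileLang kernel applies in steps 2-4.
    y = x.clone()
    n = y.shape[-1]
    h = 1
    while h < n:
        # View the last dimension as blocks of size 2*h and pair the two
        # halves of each block.
        y = y.reshape(*y.shape[:-1], n // (2 * h), 2, h)
        a, b = y[..., 0, :], y[..., 1, :]
        y = torch.stack((a + b, a - b), dim=-2).reshape(*y.shape[:-3], n)
        h *= 2
    return y

On a small double-precision input, fwht_reference(x) and ref_program(x) should agree to within floating-point tolerance.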