OpenDAS / tilelang · Commits

Commit caa6dd3f (unverified)
Authored Nov 24, 2025 by Tong WU; committed via GitHub on Nov 24, 2025.
[Feat] Support warp reduce (#1316)
* [Feat] Support warp reduce
* lint
* add test
* lint
Parent: 6c2162a9
Showing 7 changed files with 259 additions and 0 deletions.
src/op/builtin.cc                                               +25  −0
src/op/builtin.h                                                +25  −0
src/target/codegen_cuda.cc                                      +10  −0
src/tl_templates/cuda/reduce.h                                  +31  −0
testing/python/language/test_tilelang_language_warp_reduce.py   +83  −0
tilelang/language/__init__.py                                    +5  −0
tilelang/language/reduce.py                                     +80  −0
src/op/builtin.cc
```diff
@@ -341,5 +341,30 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 } // namespace tl
 } // namespace tvm
```
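Note: each builtin takes exactly one input and is marked `kOpaque`, so TIR passes treat the call as having opaque side effects rather than reordering or folding it. Once registered here, the op is resolvable from the Python frontend via TVM's op registry; a minimal sketch (illustrative, not part of this commit):

```python
# Sketch: resolve one of the newly registered TL builtins from Python.
# The "tl." prefix is the TVM op namespace used by tilelang; reduce.py
# below looks the ops up the same way.
from tvm import tir

warp_sum_op = tir.op.Op.get("tl.warp_reduce_sum")
print(warp_sum_op)  # expected to print the registered Op, e.g. Op(tl.warp_reduce_sum)
```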
src/op/builtin.h
```diff
@@ -571,6 +571,31 @@ TVM_DLL const Op &device_assert();
  */
 TVM_DLL const Op &device_assert_with_msg();
 
+/*!
+ * \brief tilelang intrinsic for warp reduction sum.
+ */
+TVM_DLL const Op &warp_reduce_sum();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction max.
+ */
+TVM_DLL const Op &warp_reduce_max();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction min.
+ */
+TVM_DLL const Op &warp_reduce_min();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction bitand.
+ */
+TVM_DLL const Op &warp_reduce_bitand();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction bitor.
+ */
+TVM_DLL const Op &warp_reduce_bitor();
+
 } // namespace tl
 } // namespace tvm
```
src/target/codegen_cuda.cc
```diff
@@ -2609,6 +2609,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
     os << func_name << "(" << PrintExpr(op->args[0]) << ", "
        << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_sum())) {
+    os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_max())) {
+    os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_min())) {
+    os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_bitand())) {
+    os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_bitor())) {
+    os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
```
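The codegen maps each TIR intrinsic one-to-one onto a `tl::` device function of the same name (defined in `reduce.h` below). A quick way to confirm the emitted call, assuming a CUDA device and the `get_kernel` helper from the test file in this commit:

```python
# Sketch: the generated CUDA source should contain the tl:: call emitted above.
kernel = get_kernel("max", "float32")
assert "tl::warp_reduce_max(" in kernel.get_kernel_source()
```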
src/tl_templates/cuda/reduce.h
```diff
@@ -250,4 +250,35 @@ template <int threads, int Axis = 0, bool reverse = false> struct CumSum2D {
   }
 };
 
+template <typename T, typename ReduceOp>
+TL_DEVICE T warp_reduce(T value, ReduceOp op) {
+  constexpr uint32_t mask = 0xffffffff;
+  value = op(value, __shfl_xor_sync(mask, value, 16));
+  value = op(value, __shfl_xor_sync(mask, value, 8));
+  value = op(value, __shfl_xor_sync(mask, value, 4));
+  value = op(value, __shfl_xor_sync(mask, value, 2));
+  value = op(value, __shfl_xor_sync(mask, value, 1));
+  return value;
+}
+
+template <typename T> TL_DEVICE T warp_reduce_sum(T value) {
+  return warp_reduce<T>(value, SumOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_max(T value) {
+  return warp_reduce<T>(value, MaxOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_min(T value) {
+  return warp_reduce<T>(value, MinOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_bitand(T value) {
+  return warp_reduce<T>(value, BitAndOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_bitor(T value) {
+  return warp_reduce<T>(value, BitOrOp());
+}
+
 } // namespace tl
```
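`warp_reduce` is the standard XOR-butterfly reduction: at offsets 16, 8, 4, 2, 1 each lane combines its value with the lane whose index differs in exactly that bit, so after five rounds all 32 lanes hold the full reduction. A host-side Python emulation of the pattern (illustrative only, not part of the commit):

```python
# Emulate the five __shfl_xor_sync rounds of a full 32-lane warp.
# After the loop, every lane holds op folded over all 32 inputs.
def emulate_warp_reduce(lanes, op):
    assert len(lanes) == 32
    for offset in (16, 8, 4, 2, 1):
        lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(32)]
    return lanes

vals = emulate_warp_reduce(list(range(32)), lambda a, b: a + b)
assert all(v == sum(range(32)) for v in vals)  # every lane sees the full sum
```

The full mask `0xffffffff` assumes all 32 lanes of the warp are active, which holds here because the test kernels launch with `threads=32`.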
testing/python/language/test_tilelang_language_warp_reduce.py
new file mode 100644
```python
import torch

import tilelang
import tilelang.testing
import tilelang.language as T


@tilelang.jit
def get_kernel(reduce_op: str, dtype: str):
    assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]

    @T.prim_func
    def main(x: T.Tensor((32), dtype)):
        with T.Kernel(1, threads=32):
            tx = T.get_thread_binding(0)
            local_val = T.alloc_local([1], dtype)
            local_val[0] = x[tx]
            reduced_val = T.alloc_local([1], dtype)
            if reduce_op == "sum":
                reduced_val[0] = T.warp_reduce_sum(local_val[0])
            elif reduce_op == "max":
                reduced_val[0] = T.warp_reduce_max(local_val[0])
            elif reduce_op == "min":
                reduced_val[0] = T.warp_reduce_min(local_val[0])
            elif reduce_op == "bitand":
                reduced_val[0] = T.warp_reduce_bitand(local_val[0])
            elif reduce_op == "bitor":
                reduced_val[0] = T.warp_reduce_bitor(local_val[0])
            x[tx] = reduced_val[0]

    return main


def test_warp_reduce_sum():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel('sum', 'float32')
    ref = torch.full_like(a, a.sum())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_max():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel("max", 'float32')
    print(kernel.get_kernel_source())
    ref = torch.full_like(a, a.max())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_min():
    a = torch.randn((32,), dtype=torch.float32, device='cuda')
    kernel = get_kernel("min", 'float32')
    ref = torch.full_like(a, a.min())
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_bitand():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
    kernel = get_kernel("bitand", 'int32')
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val & a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)


def test_warp_reduce_bitor():
    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
    kernel = get_kernel("bitor", 'int32')
    ref_val = a[0]
    for i in range(1, a.shape[0]):
        ref_val = ref_val | a[i]
    ref = torch.full_like(a, ref_val)
    kernel(a)
    torch.testing.assert_close(a, ref)


if __name__ == "__main__":
    tilelang.testing.main()
```
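One subtlety in the float tests: the butterfly combines values in a different order than `torch.sum`, so the `sum` result is not guaranteed to be bitwise identical to the reference; `torch.testing.assert_close` uses default tolerances that absorb the difference. The `max`/`min` and integer `bitand`/`bitor` reductions are order-independent and exact.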
tilelang/language/__init__.py
```diff
@@ -65,6 +65,11 @@ from .reduce import (
     reduce_bitxor,  # noqa: F401
     cumsum,  # noqa: F401
     finalize_reducer,  # noqa: F401
+    warp_reduce_sum,  # noqa: F401
+    warp_reduce_max,  # noqa: F401
+    warp_reduce_min,  # noqa: F401
+    warp_reduce_bitand,  # noqa: F401
+    warp_reduce_bitor,  # noqa: F401
 )
 from .print import print, device_assert  # noqa: F401
 from .customize import (
```
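With these re-exports, the intrinsics become available under the usual frontend alias:

```python
import tilelang.language as T
# Inside a T.prim_func (see the test file above):
#     reduced_val[0] = T.warp_reduce_sum(local_val[0])
```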
tilelang/language/reduce.py
```diff
@@ -325,3 +325,83 @@ def finalize_reducer(reducer: tir.Buffer):
         tir.op.Op.get("tl.finalize_reducer"),
         reducer.access_ptr("w"),
     )
+
+
+def warp_reduce_sum(value: tir.PrimExpr):
+    """Perform warp reduction sum on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle
+    operations. Each thread provides a register `value`, and after the
+    reduction, all threads will have the sum of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced sum value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_sum"), value)
+
+
+def warp_reduce_max(value: tir.PrimExpr):
+    """Perform warp reduction max on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle
+    operations. Each thread provides a register `value`, and after the
+    reduction, all threads will have the max of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced max value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_max"), value)
+
+
+def warp_reduce_min(value: tir.PrimExpr):
+    """Perform warp reduction min on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle
+    operations. Each thread provides a register `value`, and after the
+    reduction, all threads will have the min of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced min value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_min"), value)
+
+
+def warp_reduce_bitand(value: tir.PrimExpr):
+    """Perform warp reduction bitwise-and on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle
+    operations. Each thread provides a register `value`, and after the
+    reduction, all threads will have the bitwise-and of all values across the
+    warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced bitwise-and value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitand"), value)
+
+
+def warp_reduce_bitor(value: tir.PrimExpr):
+    """Perform warp reduction bitwise-or on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle
+    operations. Each thread provides a register `value`, and after the
+    reduction, all threads will have the bitwise-or of all values across the
+    warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced bitwise-or value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitor"), value)
```
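Each wrapper is a thin shim over `tir.call_intrin`: the result dtype is taken from the input expression, and the call targets the op registered in `builtin.cc`. A standalone sketch of the expression it builds (illustrative, outside any kernel):

```python
# Sketch: the TIR call node that warp_reduce_sum constructs for a float32 operand.
from tvm import tir

v = tir.FloatImm("float32", 1.0)
expr = tir.call_intrin("float32", tir.op.Op.get("tl.warp_reduce_sum"), v)
print(expr.op, expr.dtype)  # -> the tl.warp_reduce_sum Op, float32
```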