OpenDAS / tilelang · Commits · bc37ea69

Commit bc37ea69 (unverified), authored Oct 20, 2025 by Lei Wang, committed via GitHub on Oct 20, 2025.
Parent: a7730272

[Language] Efficient `T.reduce_` with shared memory input/output (#1080)

* Support reduce ss
* lint fix
* test fix
* lint fix
Changes: 7 files changed, 576 additions and 369 deletions.

- src/op/reduce.cc: +237 / -187
- src/tl_templates/cuda/reduce.h: +47 / -0
- src/tl_templates/hip/reduce.h: +65 / -0
- testing/python/language/test_tilelang_language_reduce.py: +226 / -0 (new file)
- testing/python/language/test_tilelang_language_reduce_max.py: +0 / -92 (deleted)
- testing/python/language/test_tilelang_language_reduce_sum.py: +0 / -89 (deleted)
- tilelang/jit/adapter/wrapper.py: +1 / -1
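In short, `T.reduce_*` can now read from and write to shared-memory buffers directly. A minimal TileLang sketch of that usage, mirroring the `_make_shared_reduce` helper added in the new test file (the sizes and dtype here are illustrative):

import tilelang as tl
import tilelang.language as T


def shared_reduce_sum(M=64, N=64, dtype="float32"):
    # Both the reduce input and the reduce output live in shared memory.
    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)
            T.copy(A, A_shared)
            T.reduce_sum(A_shared, B_shared, dim=1)  # shared -> shared reduction
            T.copy(B_shared, B)

    return main


# jit_kernel = tl.compile(shared_reduce_sum(), out_idx=-1)

Compiling such a kernel goes through the new shared-to-shared branch added to ReduceOpNode::Lower below.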
src/op/reduce.cc

@@ -175,207 +175,257 @@ std::string ReduceOpNode::MakeCodegenReducer() const {

This hunk rewrites ReduceOpNode::Lower: the existing fragment-to-fragment lowering is kept behind an explicit scope check, and a new branch lowers reductions whose source and destination both live in shared memory to a tl::SharedReduceWarp extern call. The updated function:

 * @return Stmt Lowered TIR statement implementing the reduction.
 */
Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  auto get_buffer = [&](const Buffer &buf) {
    if (T.buffer_remap.count(buf))
      return T.buffer_remap[buf];
    return buf;
  };
  auto src_scope = this->src.scope();
  auto dst_scope = this->dst.scope();
  if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
    Buffer src_buffer = get_buffer(this->src);
    Buffer dst_buffer = get_buffer(this->dst);
    Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
    Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
    size_t src_dim = src_layout->InputDim();
    size_t dst_dim = dst_layout->InputDim();
    bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
    if (is_1d_reduce) {
      ICHECK(is_one(dst_layout->OutputShape().back()))
          << "Reduce for scalar not implemented.";
    } else {
      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
    }
    Array<IterVar> dst_vars;
    for (size_t i = 0; i < dst_dim; ++i) {
      Var var = Var(std::string{char('i' + i)});
      dst_vars.push_back(IterVar(Range(0, dst_layout->InputShape()[i]), var,
                                 IterVarType::kDataPar));
    }
    Array<IterVar> src_vars;
    if (!is_1d_reduce) {
      src_vars = dst_vars;
    }
    Range reduce_dom(0, src_layout->InputShape()[this->dim]);
    IterVar reduce_iv(reduce_dom, Var("rv"), IterVarType::kDataPar);
    src_vars.insert(src_vars.begin() + this->dim, reduce_iv);
    Array<PrimExpr> src_indices = src_layout->Forward(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
    Array<PrimExpr> dst_indices = dst_layout->Forward(
        dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));

    Array<Stmt> stmts;
    bool require_init = this->clear;
    if (this->type->isSum() || this->type->isAbsSum() ||
        this->type->isBitAnd() || this->type->isBitOr() ||
        this->type->isBitXor()) {
      require_init = true;
    }
    Buffer clear_buffer = dst_buffer;
    bool need_duplicate = false;
    if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
      need_duplicate = true;
    } else if (this->type->isBitAnd() && !this->clear) {
      need_duplicate = true;
    } else if ((this->type->isBitOr() || this->type->isBitXor()) &&
               !this->clear) {
      need_duplicate = true;
    }
    if (need_duplicate) {
      // Create a new buffer with same shape and dtype as dst_buffer
      clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
                                 dst_buffer->name + "_clear",
                                 GetPtrStorageScope(dst_buffer->data));
    }
    // make reduce-init stmt
    if (require_init) {
      stmts.push_back(
          BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
    }
    // make thread-local reduce
    Array<PrimExpr> src_indice_compressed;
    Array<IterVar> src_var_compressed;
    for (size_t i = 0; i < src_layout->OutputDim(); ++i) {
      PrimExpr expr;
      IterVar var;
      std::tie(expr, var) = CompressIterator(src_indices[i], src_vars,
                                             src_vars[this->dim]->var, analyzer);
      src_indice_compressed.push_back(expr);
      src_var_compressed.push_back(var);
    }
    Stmt reduce_local = BufferStore(
        clear_buffer,
        this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
                         BufferLoad(src_buffer, src_indice_compressed)),
        dst_indices);
    for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
      reduce_local =
          For(src_var_compressed[i]->var, 0, src_var_compressed[i]->dom->extent,
              ForKind::kUnrolled, reduce_local, std::nullopt,
              {{tir::attr::pragma_unroll_explicit, Bool(false)}});
    }
    stmts.push_back(reduce_local);
    // make inter-thread reduce
    PrimExpr src_thread = src_layout->ForwardThread(
        src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
    auto iter_sum =
        arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
    for (const auto &iter_split : iter_sum->args) {
      auto mark = iter_split->source->source.as<Var>();
      ICHECK(mark) << "Not a normalized iterator: " << iter_split->source;
      if (mark.value().same_as(src_vars[this->dim]->var)) {
        auto scale = as_const_int(iter_split->scale);
        auto extent = as_const_int(iter_split->extent);
        ICHECK(scale != nullptr && extent != nullptr);
        if (*extent == 1)
          continue;
        int reducing_threads = (*extent) * (*scale);
        std::stringstream ss;
        auto thread_offset = T.thread_bounds->min;
        if (TargetIsHopper(T.target) || TargetIsSm100(T.target)) {
          auto all_threads = T.thread_bounds->extent;
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ", " << all_threads << ">::run_hopper";
        } else {
          ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
             << reducing_threads << ", " << (*scale) << ", " << thread_offset
             << ">::run";
        }
        Array<PrimExpr> thread_reduce_args = {
            StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
        if (reducing_threads >= 32) {
          PrimExpr workspace = T.AddWorkspace(
              *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
          thread_reduce_args.push_back(workspace);
        }
        auto call = Call(clear_buffer->dtype, builtin::call_extern(),
                         thread_reduce_args);
        stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
      }
    }
    if (need_duplicate) {
      PrimExpr src_val = BufferLoad(clear_buffer, dst_indices);
      PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
      // copy clear_buffer to dst_buffer
      PrimExpr update;
      if (this->type->isSum() || this->type->isAbsSum()) {
        update = dst_val + src_val;
      } else if (this->type->isBitAnd()) {
        update = this->clear ? src_val : bitwise_and(dst_val, src_val);
      } else if (this->type->isBitOr()) {
        update = bitwise_or(dst_val, src_val);
      } else if (this->type->isBitXor()) {
        update = bitwise_xor(dst_val, src_val);
      } else {
        LOG(FATAL) << "Unsupported reduce type: " << this->type->type;
      }
      stmts.push_back(BufferStore(dst_buffer, update, dst_indices));
    }
    // make the outer spatial loop
    Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
    for (int i = static_cast<int>(dst_layout->InputDim()) - 1; i >= 0; --i) {
      body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent,
                 ForKind::kParallel, body);
    }
    if (dst_layout->InputDim() > 0) {
      body = PartitionLoop(Downcast<For>(body), T.thread_var, analyzer,
                           dst_layout);
    } else {
      PrimExpr guard = (T.thread_var == T.thread_bounds->min);
      body = IfThenElse(guard, body);
    }
    if (need_duplicate) {
      body = Allocate(clear_buffer->data, clear_buffer->dtype,
                      clear_buffer->shape, const_true(), body);
    }
    return body;
  }

  auto is_shared_scope = [](const std::string &scope) {
    return scope == "shared" || scope == "shared.dyn";
  };
  if (is_shared_scope(src_scope) && is_shared_scope(dst_scope)) {
    Buffer src_buffer = get_buffer(this->src);
    Buffer dst_buffer = get_buffer(this->dst);
    size_t src_dim = src_buffer->shape.size();
    size_t dst_dim = dst_buffer->shape.size();
    bool is_1d_reduce = (src_dim == dst_dim && dst_dim == 1);
    if (!is_1d_reduce) {
      ICHECK_EQ(src_dim, dst_dim + 1) << "Reduce dimension mismatch.";
    } else {
      ICHECK_EQ(dst_dim, 1U) << "Expect scalar layout for 1D reduce.";
    }
    auto thread_extent = as_const_int(T.thread_bounds->extent);
    ICHECK(thread_extent) << "Shared-memory reduce requires static thread extent.";
    int threads = *thread_extent;
    if (TargetIsCuda(T.target)) {
      ICHECK_EQ(threads % 32, 0)
          << "Shared reduce expects blockDim.x to be a multiple of 32 on CUDA.";
    } else if (TargetIsRocm(T.target)) {
      ICHECK_EQ(threads % 64, 0)
          << "Shared reduce expects blockDim.x to be a multiple of 64 on HIP.";
    }
    bool use_abs = this->type->isAbsSum() || this->type->isAbsMax();
    bool need_accumulate =
        (!this->clear) &&
        (this->type->isSum() || this->type->isAbsSum() ||
         this->type->isBitAnd() || this->type->isBitOr() ||
         this->type->isBitXor());
    PrimExpr reduce_extent = src_buffer->shape[this->dim];
    PrimExpr tail_extent = make_const(DataType::Int(32), 1);
    for (size_t i = this->dim + 1; i < src_dim; ++i) {
      tail_extent = analyzer->Simplify(tail_extent * src_buffer->shape[i]);
    }
    PrimExpr total_dest = make_const(DataType::Int(32), 1);
    for (size_t i = 0; i < dst_dim; ++i) {
      total_dest = analyzer->Simplify(total_dest * dst_buffer->shape[i]);
    }
    std::stringstream ss;
    std::string reducer = this->MakeCodegenReducer();
    ss << "tl::SharedReduceWarp<" << reducer << ", " << threads << ", "
       << (use_abs ? "true" : "false") << ", "
       << (need_accumulate ? "true" : "false") << ">::run";
    Array<PrimExpr> call_args = {StringImm(ss.str()),
                                 src_buffer.access_ptr(1),
                                 dst_buffer.access_ptr(3),
                                 cast(DataType::Int(32), total_dest),
                                 cast(DataType::Int(32), reduce_extent),
                                 cast(DataType::Int(32), tail_extent),
                                 this->MakeInitValue()};
    return Evaluate(Call(dst_buffer->dtype, builtin::call_extern(), call_args));
  }

  LOG(FATAL) << "Reduce for buffers in scope (" << src_scope << ", "
             << dst_scope << ") is not implemented.";
  return Stmt();
}

LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...
src/tl_templates/cuda/reduce.h

@@ -40,6 +40,53 @@ struct BitXorOp {
  }
};

template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
  template <typename T>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int total_dest, int reduce_extent, int tail,
                            T init_value) {
    if (total_dest <= 0 || reduce_extent <= 0)
      return;
    constexpr int kWarpSize = 32;
    static_assert(Threads % kWarpSize == 0,
                  "SharedReduceWarp expects blockDim.x to be a multiple of "
                  "warp size on CUDA.");
    const int tid = threadIdx.x;
    const int warp_id = tid / kWarpSize;
    const int lane = tid % kWarpSize;
    const int num_warps = Threads / kWarpSize;
    for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
      const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
      const int suffix = tail == 1 ? 0 : dest_idx % tail;
      const int src_base = (prefix * reduce_extent) * tail + suffix;
      const int dst_index = prefix * tail + suffix;
      T partial = init_value;
      for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
        T val = src[src_base + rv * tail];
        if constexpr (UseAbs) {
          val = val < T(0) ? -val : val;
        }
        partial = Reducer()(partial, val);
      }
      unsigned mask = __activemask();
      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
        T other = __shfl_down_sync(mask, partial, offset);
        partial = Reducer()(partial, other);
      }
      if (lane == 0) {
        if constexpr (NeedAccumulate) {
          partial = Reducer()(dst[dst_index], partial);
        }
        dst[dst_index] = partial;
      }
    }
  }
};

template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
...
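The per-warp kernel above maps each destination element to one warp: the flat destination index dest_idx splits into a prefix (dimensions before the reduced axis) and a suffix within tail (the product of the dimensions after it), lanes stride over the reduced extent, and a shuffle tree combines the partials. A plain-Python sketch of the same indexing (a hypothetical reference helper, not part of the commit) may make the layout easier to follow:

from typing import Callable, List


def shared_reduce_reference(src: List[float], dst: List[float], total_dest: int,
                            reduce_extent: int, tail: int, init_value: float,
                            reducer: Callable[[float, float], float],
                            use_abs: bool = False,
                            need_accumulate: bool = False) -> None:
    # src is a flattened [prefix, reduce_extent, tail] array (row-major);
    # dst is a flattened [prefix, tail] array. Each dst element is the
    # reduction of the corresponding src slice over the middle axis.
    for dest_idx in range(total_dest):
        prefix, suffix = divmod(dest_idx, tail)
        src_base = prefix * reduce_extent * tail + suffix
        dst_index = prefix * tail + suffix
        partial = init_value
        for rv in range(reduce_extent):
            val = src[src_base + rv * tail]
            if use_abs:
                val = abs(val)
            partial = reducer(partial, val)
        if need_accumulate:
            partial = reducer(dst[dst_index], partial)
        dst[dst_index] = partial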
src/tl_templates/hip/reduce.h

@@ -22,6 +22,71 @@ struct MinOp {
  }
};

struct BitAndOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x & y;
  }
};

struct BitOrOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x | y;
  }
};

struct BitXorOp {
  template <typename T> TL_DEVICE T operator()(T const &x, T const &y) {
    return x ^ y;
  }
};

template <class Reducer, int Threads, bool UseAbs, bool NeedAccumulate>
struct SharedReduceWarp {
  template <typename T>
  static TL_DEVICE void run(const T *__restrict__ src, T *__restrict__ dst,
                            int total_dest, int reduce_extent, int tail,
                            T init_value) {
    if (total_dest <= 0 || reduce_extent <= 0)
      return;
    constexpr int kWarpSize = 64;
    static_assert(Threads % kWarpSize == 0,
                  "SharedReduceWarp expects blockDim.x to be a multiple of "
                  "wave size on HIP.");
    const int tid = threadIdx.x;
    const int warp_id = tid / kWarpSize;
    const int lane = tid % kWarpSize;
    const int num_warps = Threads / kWarpSize;
    for (int dest_idx = warp_id; dest_idx < total_dest; dest_idx += num_warps) {
      const int prefix = tail == 1 ? dest_idx : dest_idx / tail;
      const int suffix = tail == 1 ? 0 : dest_idx % tail;
      const int src_base = (prefix * reduce_extent) * tail + suffix;
      const int dst_index = prefix * tail + suffix;
      T partial = init_value;
      for (int rv = lane; rv < reduce_extent; rv += kWarpSize) {
        T val = src[src_base + rv * tail];
        if constexpr (UseAbs) {
          val = val < T(0) ? -val : val;
        }
        partial = Reducer()(partial, val);
      }
      for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
        T other = __shfl_down(partial, offset, kWarpSize);
        partial = Reducer()(partial, other);
      }
      if (lane == 0) {
        if constexpr (NeedAccumulate) {
          partial = Reducer()(dst[dst_index], partial);
        }
        dst[dst_index] = partial;
      }
    }
  }
};

template <class Reducer, int threads, int scale, int thread_offset = 0>
struct AllReduce {
  static_assert(threads == 1024 || threads == 512 || threads == 256 || ...
testing/python/language/test_tilelang_language_reduce.py (new file, 0 → 100644)

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()


def _make_shared_reduce(M, N, dtype, reduce_cb):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_shared = T.alloc_shared((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)
            T.copy(A, A_shared)
            reduce_cb(T, A_shared, B_shared)
            T.copy(B_shared, B)

    return main


def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()
    profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)


def reduce_max_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.reduce_max(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main


def reduce_sum_test(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.reduce_sum(A_local, B_local, dim=1)
            T.copy(B_local, B)

    return main


def reduce_sum_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))


def reduce_max_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))


def reduce_min_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1))


def reduce_abssum_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1))


def reduce_absmax_ss(M, N, dtype="float32"):
    return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1))


def run_reduce_sum(M, N, dtype="float32", mode="rr"):
    if mode == "rr":
        program = reduce_sum_test(M, N, dtype)
    elif mode == "ss":
        program = reduce_sum_ss(M, N, dtype)
    else:
        raise NotImplementedError("run_reduce_sum only supports rr and ss")
    _run_program(program, lambda A: A.sum(dim=1))


def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"):
    program = program_builder(M, N, dtype)
    _run_program(program, ref_program)


def run_reduce_max(M, N, dtype="float16"):
    program = reduce_max_test(M, N, dtype)
    _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)


def test_reduce_sum_shared():
    run_reduce_sum(64, 64, mode="ss")
    run_reduce_sum(32, 96, mode="ss")


def test_reduce_max():
    run_reduce_max(256, 256, "float16")
    run_reduce_max(512, 128, "float16")
    run_reduce_max(256, 256, "float32")


def test_reduce_max_shared():
    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")


def test_reduce_min_shared():
    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, "float32")


def test_reduce_abssum_shared():
    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, "float32")


def test_reduce_absmax_shared():
    run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, "float32")


def reduce_sum_test_clear(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_sum_clear(M, N, dtype="float32"):
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, "float32")
    run_reduce_sum_clear(512, 128, "float32")
    run_reduce_sum_clear(128, 512, "float32")


def reduce_max_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_max_clear(M, N, dtype="float16"):
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.max(dim=1).values

    import torch

    dummy_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummy_A)
    tl_out = jit_kernel(dummy_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, "float16")


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/language/test_tilelang_language_reduce_max.py (deleted, 100644 → 0; last version at a7730272)

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl


def reduce_max_test(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            # Copy input to local
            T.copy(A, A_local)

            # Perform reduce_max operation
            T.reduce_max(A_local, B_local, dim=1)

            # Copy result back
            T.copy(B_local, B)

    return main


def run_reduce_max(M, N, dtype="float16"):
    program = reduce_max_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.max(dim=1).values

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_max():
    # Test different sizes
    run_reduce_max(256, 256)
    run_reduce_max(512, 128)
    run_reduce_max(128, 512)

    # Test different dtypes
    run_reduce_max(256, 256, "float32")
    run_reduce_max(256, 256, "float16")


def reduce_max_test_clear(M, N, dtype="float16"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, -T.infinity(dtype))
            T.reduce_max(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_max_clear(M, N, dtype="float16"):
    program = reduce_max_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    print(jit_kernel.get_kernel_source())

    def ref_program(A):
        return A.max(dim=1).values

    import torch

    dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummp_A)
    tl_out = jit_kernel(dummp_A)
    print(tl_out)
    print(ref_out)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_max_clear():
    run_reduce_max_clear(256, 256, "float16")


if __name__ == "__main__":
    tilelang.testing.main()
testing/python/language/test_tilelang_language_reduce_sum.py (deleted, 100644 → 0; last version at a7730272)

from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl

tilelang.testing.set_random_seed()


def reduce_sum_test(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)

            # Copy input to local
            T.copy(A, A_local)

            # Perform reduce_sum operation
            T.reduce_sum(A_local, B_local, dim=1)

            # Copy result back
            T.copy(B_local, B)

    return main


def run_reduce_sum(M, N, dtype="float32"):
    program = reduce_sum_test(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)
    profiler = jit_kernel.get_profiler()

    def ref_program(A):
        return A.sum(dim=1)

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_reduce_sum():
    # Test different sizes
    run_reduce_sum(256, 256)
    run_reduce_sum(512, 128)
    run_reduce_sum(128, 512)


def reduce_sum_test_clear(M, N, dtype="float32"):
    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor((M, N), dtype),
            B: T.Tensor((M,), dtype),
    ):
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_local = T.alloc_fragment((M,), dtype)
            T.copy(A, A_local)
            T.fill(B_local, 1)
            T.reduce_sum(A_local, B_local, dim=1, clear=False)
            T.copy(B_local, B)

    return main


def run_reduce_sum_clear(M, N, dtype="float32"):
    program = reduce_sum_test_clear(M, N, dtype)
    jit_kernel = tl.compile(program, out_idx=-1)

    def ref_program(A):
        return A.sum(dim=1) + 1

    import torch

    dummp_A = torch.randn((M, N), dtype=getattr(torch, dtype)).cuda()
    ref_out = ref_program(dummp_A)
    tl_out = jit_kernel(dummp_A)
    torch.testing.assert_close(tl_out, ref_out, atol=1e-2, rtol=1e-2)


def test_reduce_sum_clear():
    run_reduce_sum_clear(256, 256, "float32")
    run_reduce_sum_clear(512, 128, "float32")
    run_reduce_sum_clear(128, 512, "float32")


if __name__ == "__main__":
    tilelang.testing.main()
tilelang/jit/adapter/wrapper.py

The error check in the generated attribute-setting snippet now compares against the runtime-API constant cudaSuccess instead of the driver-API constant CUDA_SUCCESS, matching the cudaError_t returned by cudaFuncSetAttribute.

@@ -12,7 +12,7 @@ from tvm.tir.stmt_functor import post_order_visit

 PREDEF_ATTRIBUTE_SET_DYNAMIC_MEMORY = """
     cudaError_t result_{0} = cudaFuncSetAttribute({0}, cudaFuncAttributeMaxDynamicSharedMemorySize, {1});
-    if (result_{0} != CUDA_SUCCESS) {{
+    if (result_{0} != cudaSuccess) {{
         snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", {1}, cudaGetErrorString(result_{0}));
         return -1;
     }}