OpenDAS / tilelang · Commits

Commit d0742860 (Unverified)
Authored Aug 15, 2025 by Gabriel Wu; committed by GitHub, Aug 15, 2025

[Chore] fix typos (#719)

* chore: fix typos
* chore: fix ruff
* chore: fix clang-format

Parent: 6545b084

Showing 20 changed files with 49 additions and 62 deletions.

benchmark/matmul/benchmark_matmul.py                        +1  -4
benchmark/matmul/benchmark_matmul_intrinsic.py              +1  -4
docs/deeplearning_operators/gemv.md                         +1  -1
examples/analyze/example_conv_analyze.py                    +2  -4
examples/analyze/example_gemm_analyze.py                    +1  -4
examples/bitnet-1.58b/modeling_bitnet.py                    +1  -1
examples/gemm/example_gemm_autotune.py                      +1  -4
src/op/gemm_sp.cc                                           +1  -1
src/target/codegen_cpp.h                                    +1  -1
src/target/codegen_webgpu.cc                                +3  -3
src/tl_templates/cpp/half.hpp                               +13 -13
src/tl_templates/cuda/common.h                              +1  -1
src/tl_templates/cuda/debug.h                               +1  -1
src/transform/atomicadd_vectorize.cc                        +1  -1
src/transform/merge_shared_memory_allocations.cc            +4  -4
src/transform/storage_rewrite.cc                            +9  -9
src/transform/thread_storage_sync.cc                        +1  -1
src/transform/vectorize_loop.cc                             +3  -3
testing/python/language/test_tilelang_language_reshape.py   +2  -1
tilelang/autotuner/tuner.py                                 +1  -1

benchmark/matmul/benchmark_matmul.py

@@ -53,10 +53,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
 
-    if torch.version.hip is not None:
-        arch = CDNA("hip")
-    else:
-        arch = CUDA("cuda")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
 
     topk = 10
     carve_template = MatmulTemplate(
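
For readers skimming the diff: the replacement collapses the explicit `if`/`else` into one conditional expression without changing which architecture object is selected; this is the kind of simplification the "chore: fix ruff" part of the commit message refers to. A minimal sketch of the equivalence, using hypothetical stand-in `CUDA`/`CDNA` classes rather than the real `tilelang.carver.arch` ones:

```python
# Sketch only: CUDA/CDNA here are hypothetical stand-ins, not tilelang.carver.arch.
class CUDA:
    def __init__(self, target):
        self.target = target

class CDNA:
    def __init__(self, target):
        self.target = target

def pick_arch_if_else(hip_version):
    # Shape of the removed four-line block.
    if hip_version is not None:
        arch = CDNA("hip")
    else:
        arch = CUDA("cuda")
    return arch

def pick_arch_ternary(hip_version):
    # Shape of the added one-liner: same selection, as a conditional expression.
    return CUDA("cuda") if hip_version is None else CDNA("hip")

# Both forms agree for the CUDA (hip_version is None) and ROCm cases.
for hip_version in (None, "6.2"):
    assert type(pick_arch_if_else(hip_version)) is type(pick_arch_ternary(hip_version))
```

The same one-line selection appears again below in benchmark_matmul_intrinsic.py, the two analyze examples, and example_gemm_autotune.py.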

benchmark/matmul/benchmark_matmul_intrinsic.py

@@ -187,10 +187,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
 
-    if torch.version.hip is not None:
-        arch = CDNA("hip")
-    else:
-        arch = CUDA("cuda")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
 
     topk = 10
     carve_template = MatmulTemplate(

docs/deeplearning_operators/gemv.md

@@ -252,7 +252,7 @@ def splitk_gemv_vectorized(
     return main
 ```
 
-With vectorized read, now the kernel finishs in **~0.0084 ms**, which is getting close to cuBLAS performance.
+With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.
 
 ## `tvm_thread_allreduce` Instead of `atomicAdd`

examples/analyze/example_conv_analyze.py

@@ -4,6 +4,7 @@ from tilelang.carver.arch import CUDA
 from tilelang.carver.arch import CDNA
 from tilelang.layout import make_swizzled_layout
 import torch
+
 N = 64
 C = 256
 H = 512

@@ -95,10 +96,7 @@ def kernel(N,
 def main():
     my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-    if torch.version.hip is not None:
-        cuda_device = CDNA("hip")
-    else:
-        cuda_device = CUDA("cuda")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(result)
     print(f"Analyzed FLOPs: {result.total_flops}")

examples/analyze/example_gemm_analyze.py

@@ -49,10 +49,7 @@ def kernel(
 def main():
     my_func = kernel(128, 128, 32, 3, 128, True)
-    if torch.version.hip is not None:
-        cuda_device = CDNA("hip")
-    else:
-        cuda_device = CUDA("cuda")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(f"Analyzed FLOPs: {result.total_flops}")

examples/bitnet-1.58b/modeling_bitnet.py

@@ -1373,7 +1373,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel):
                     cache_length + input_ids.shape[1] > max_cache_length):
                 attention_mask = attention_mask[:, -max_cache_length:]
 
-        position_ids = kwargs.get("position_ids", None)
+        position_ids = kwargs.get("position_ids")
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
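
The change is behavior-preserving: `dict.get(key)` already returns `None` when the key is missing, so the explicit `None` default is redundant (one of the simplifications ruff's flake8-simplify checks flag). A small self-contained check, using a stand-in `kwargs` dict rather than the real generation kwargs:

```python
# Stand-in kwargs dict for illustration; not the actual generation kwargs.
kwargs = {"input_ids": [1, 2, 3]}

# Missing key: both spellings return None.
assert kwargs.get("position_ids", None) is None
assert kwargs.get("position_ids") is None

# Present key: both spellings return the stored value.
assert kwargs.get("input_ids", None) == kwargs.get("input_ids") == [1, 2, 3]
```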

examples/gemm/example_gemm_autotune.py

@@ -16,10 +16,7 @@ def ref_program(A, B):
 def get_configs(M, N, K, with_roller=False, topk=20):
     if with_roller:
-        if torch.version.hip is not None:
-            arch = CDNA("hip")
-        else:
-            arch = CUDA("cuda")
+        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
         carve_template = MatmulTemplate(
             M=M,
             N=N,

src/op/gemm_sp.cc

@@ -230,7 +230,7 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       << " and " << B.scope();
   ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
       << "Only support shared.dyn scope for E as copy from smem to rmem are "
-         "delegated to cute implemntation, found "
+         "delegated to cute implementation, found "
       << E.scope();
   ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
   ss << warp_m << ", " << warp_n << ", ";

src/target/codegen_cpp.h

@@ -95,7 +95,7 @@ private:
   Array<String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
-  /*! \brief whether to emit forwared function declarations in the resulting C
+  /*! \brief whether to emit forward function declarations in the resulting C
    * code */
   bool emit_fwd_func_decl_;

src/target/codegen_webgpu.cc

@@ -252,9 +252,9 @@ CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
   os_param_access << "]";
   func_info.launch_param_tags.push_back(os_param_access.str());
   ICHECK(!info.has_block_index_z)
-      << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
-  // anotate workgroup
+      << "blockIdx.z is not supported in WebGPU to "
+         "accommodate large blockIdx.x";
+  // annotate workgroup
   this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
                << info.workgroup_size[1] << ", " << info.workgroup_size[2]
                << ")\n";

src/tl_templates/cpp/half.hpp

@@ -284,7 +284,7 @@
 #endif
 #ifndef HALF_ENABLE_F16C_INTRINSICS
-/// Enable F16C intruction set intrinsics.
+/// Enable F16C instruction set intrinsics.
 /// Defining this to 1 enables the use of [F16C compiler
 /// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between
 /// half-precision and single-precision values which may result in improved

@@ -1674,7 +1674,7 @@ template <typename T> T half2float(unsigned int value) {
 /// \tparam R rounding mode to use
 /// \tparam E `true` for round to even, `false` for round away from zero
 /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never
-/// raise it \tparam T type to convert to (buitlin integer type with at least 16
+/// raise it \tparam T type to convert to (builtin integer type with at least 16
 /// bits precision, excluding any implicit sign bits) \param value
 /// half-precision value to convert \return rounded integer value \exception
 /// FE_INVALID if value is not representable in type \a T \exception FE_INEXACT

@@ -1778,7 +1778,7 @@ inline uint32 divide64(uint32 x, uint32 y, int &s) {
 /// \tparam R `true` to compute signed remainder, `false` for positive remainder
 /// \param x first operand as positive finite half-precision value
 /// \param y second operand as positive finite half-precision value
-/// \param quo adress to store quotient at, `nullptr` if \a Q `false`
+/// \param quo address to store quotient at, `nullptr` if \a Q `false`
 /// \return modulus of \a x / \a y
 template <bool Q, bool R>
 unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) {

@@ -2435,7 +2435,7 @@ template <typename, typename, std::float_round_style> struct half_caster;
 /// Half-precision floating-point type.
 /// This class implements an IEEE-conformant half-precision floating-point type
 /// with the usual arithmetic operators and conversions. It is implicitly
-/// convertible to single-precision floating-point, which makes artihmetic
+/// convertible to single-precision floating-point, which makes arithmetic
 /// expressions and functions with mixed-type operands to be of the most precise
 /// operand type.
 ///

@@ -2445,9 +2445,9 @@ template <typename, typename, std::float_round_style> struct half_caster;
 /// which means it can be standard-conformantly copied using raw binary copies.
 /// But in this context some more words about the actual size of the type.
 /// Although the half is representing an IEEE 16-bit type, it does not
-/// neccessarily have to be of exactly 16-bits size. But on any reasonable
+/// necessarily have to be of exactly 16-bits size. But on any reasonable
 /// implementation the actual binary representation of this type will most
-/// probably not ivolve any additional "magic" or padding beyond the simple
+/// probably not involve any additional "magic" or padding beyond the simple
 /// binary representation of the underlying 16-bit IEEE number, even if not
 /// strictly guaranteed by the standard. But even then it only has an actual
 /// size of 16 bits if your C++ implementation supports an unsigned integer type

@@ -2801,7 +2801,7 @@ public:
   static HALF_CONSTEXPR_CONST bool traps = true;
 #else
   /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref
-  /// HALF_ERRHANDLING_THROW_INVALID) is acitvated.
+  /// HALF_ERRHANDLING_THROW_INVALID) is activated.
   static HALF_CONSTEXPR_CONST bool traps = false;
 #endif

@@ -5067,7 +5067,7 @@ inline half frexp(half arg, int *exp) {
 /// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn).
 /// \param arg number to modify
 /// \param exp power of two to multiply with
-/// \return \a arg multplied by 2 raised to \a exp
+/// \return \a arg multiplied by 2 raised to \a exp
 /// \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half scalbln(half arg, long exp) {

@@ -5096,7 +5096,7 @@ inline half scalbln(half arg, long exp) {
 /// **See also:** Documentation for
 /// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). \param
 /// arg number to modify \param exp power of two to multiply with \return \a arg
-/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
+/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }

@@ -5106,7 +5106,7 @@ inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }
 /// **See also:** Documentation for
 /// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). \param
 /// arg number to modify \param exp power of two to multiply with \return \a arg
-/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
+/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half ldexp(half arg, int exp) { return scalbln(arg, exp); }

@@ -5379,7 +5379,7 @@ inline HALF_CONSTEXPR bool islessequal(half x, half y) {
          !isnan(x) && !isnan(y);
 }
-/// Quiet comarison for less or greater.
+/// Quiet comparison for less or greater.
 /// **See also:** Documentation for
 /// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater).
 /// \param x first operand

@@ -5503,7 +5503,7 @@ inline int feraiseexcept(int excepts) {
 ///
 /// **See also:** Documentation for
 /// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
-/// \param flagp adress to store flag state at
+/// \param flagp address to store flag state at
 /// \param excepts OR of flags to save
 /// \retval 0 for success
 inline int fegetexceptflag(int *flagp, int excepts) {

@@ -5520,7 +5520,7 @@ inline int fegetexceptflag(int *flagp, int excepts) {
 ///
 /// **See also:** Documentation for
 /// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
-/// \param flagp adress to take flag state from
+/// \param flagp address to take flag state from
 /// \param excepts OR of flags to restore
 /// \retval 0 for success
 inline int fesetexceptflag(const int *flagp, int excepts) {

src/tl_templates/cuda/common.h

@@ -48,7 +48,7 @@ using int4_t = int4;
   }                                                                            \
   } while (0)
-// abs function for bfloat_t and half_t since there is no implicit convertion
+// abs function for bfloat_t and half_t since there is no implicit conversion
 // method
 TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
   return half_t(__habs(x.to_half()));

src/tl_templates/cuda/debug.h

@@ -118,7 +118,7 @@ debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
                threadIdx.z, buf_name, index, var);
 }
-// Specialization for unsiged char type
+// Specialization for unsigned char type
 template <>
 __device__ void
 debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,

src/transform/atomicadd_vectorize.cc

 /*!
  * \file atomicadd_vectorize.cc
- * \brief A tool to atomatically vectorize atomic add
+ * \brief A tool to automatically vectorize atomic add
  */
 #include "../layout/layout.h"

src/transform/merge_shared_memory_allocations.cc

@@ -303,7 +303,7 @@ private:
   bool IsAppropriateSharedMemory(const Var &var) {
     return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
   }
-  // Whether do dyanmic analysis.
+  // Whether do dynamic analysis.
   bool is_dynamic_{true};
   // Whether do aggressive merge.
   bool enable_aggressive_merge_{false};

@@ -435,7 +435,7 @@ private:
       const AllocateNode *alloc = shmem_allocs_[buffer];
       auto alignment = align[i];
       // Modern nvidia architecture performs hardware swizzling (hopper
-      // wgmma/tma for exmaple) requires dynamic shared memory address to
+      // wgmma/tma for example) requires dynamic shared memory address to
       // be aligned to 1024 bytes For other devices, we align to 16 bytes
       if (shmem_alignment_map_.find(buffer) !=
           shmem_alignment_map_.end()) {

@@ -943,7 +943,7 @@ private:
    */
   StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
     ICHECK(op != nullptr);
-    // Re-use not successful, allocate a new buffer.
+    // Reuse not successful, allocate a new buffer.
     StorageEntry *entry = arena_.make<StorageEntry>();
     entry->allocs.push_back({op->buffer_var.get()});
     entry->const_nbits = const_nbits;

@@ -1046,7 +1046,7 @@ private:
       sym_free_list_.push_back(e);
     }
   }
-  // Wheather enable dyanmic analysis.
+  // Whether enable dynamic analysis.
  bool is_dynamic_{true};
  // Whether enable verbose logging.

src/transform/storage_rewrite.cc

@@ -140,9 +140,9 @@ public:
 //
 class LinearAccessPatternFinder final : public StmtExprVisitor {
 public:
-  /*! \brief record the touch hist of statment. */
+  /*! \brief record the touch hist of statement. */
   struct StmtEntry {
-    // The statment
+    // The statement
     const Object *stmt;
     // The index in the linear_seq_ to point to end of the nested scope.
     // This is only set to non-zero if stmt is a nested scope.

@@ -150,7 +150,7 @@ public:
     // offset if offset < 0, means this is the end, the begin entry is
     // current_index + offset
     int64_t scope_pair_offset{0};
-    // The buffer variables this statment touched.
+    // The buffer variables this statement touched.
     std::vector<const VarNode *> touched;
   };
   // The scope of each allocation

@@ -675,7 +675,7 @@ private:
            scope.tag != ".workspace" && scope.tag != ".vtcm";
   }
-  // Alllocate entry of node.
+  // Allocate entry of node.
   // Event entry in liveness analysis
   struct EventEntry {
     // variables we generate

@@ -785,10 +785,10 @@ private:
     for (const AllocateNode *op : e->allocs) {
       ICHECK_EQ(op->extents.size(), 1)
           << "Buffer var " << op->buffer_var->name_hint
-          << " was identified as a re-usable allocation, but has "
+          << " was identified as a reusable allocation, but has "
           << op->extents.size() << " physical dimensions. "
           << "Currently, only flat 1-d memory spaces should be "
-             "identified as re-usable "
+             "identified as reusable "
              "allocations.";
       PrimExpr sz = op->extents[0];
       auto nbits = op->dtype.bits() * op->dtype.lanes();

@@ -905,7 +905,7 @@ private:
   void PlanNewScope(const Object *op) {
     if (thread_scope_ != nullptr) {
       ICHECK(thread_scope_ == op);
-      // erase all memory atatched to this scope.
+      // erase all memory attached to this scope.
       for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
         if (it->second->attach_scope_ == op) {
           it = const_free_map_.erase(it);

@@ -1023,7 +1023,7 @@ private:
   StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope,
                          const StorageScope &scope, size_t const_nbits) {
     ICHECK(op != nullptr);
-    // Re-use not successful, allocate a new buffer.
+    // Reuse not successful, allocate a new buffer.
     auto entry = std::make_unique<StorageEntry>();
     entry->attach_scope_ = attach_scope;
     entry->scope = scope;

@@ -1050,7 +1050,7 @@ private:
     // have its own allocation with size determined at runtime.
     bool is_known_size = (const_nbits != 0);
-    // Currently, only flat memory spaces can be re-used. Packing
+    // Currently, only flat memory spaces can be reused. Packing
     // into N-d space (e.g. 2-d texture memory on GPUs) will require
     // more in-depth algorithms.
     bool is_flat_memory_space = (num_physical_dimensions == 1);

src/transform/thread_storage_sync.cc

@@ -189,7 +189,7 @@ protected:
       }
     }
   }
-  // return the exposed entries, remove unecessary ones.
+  // return the exposed entries, remove unnecessary ones.
   int sync_count = 0;
   // head are before first sync, tail are after last sync
   std::vector<AccessEntry> head, tail;

src/transform/vectorize_loop.cc

@@ -527,7 +527,7 @@ public:
     // A single var can be binded in multiple lets
     // but they have to bind to the same value.
     // This is used to allow cases when we reuse a single let
-    // expression to cosntruct a nested expr.
+    // expression to construct a nested expr.
     // (let x = 1 in x + 1) * (let x = 1 in x + 1)
     auto it = let_binding_.find(op->var);
     if (it != let_binding_.end()) {

@@ -683,7 +683,7 @@ public:
     return StmtMutator::VisitStmt_(op);
   }
-  // scalarize the statment
+  // scalarize the statement
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
     stmt = Substitute(stmt, {{var_, idx}});

@@ -701,7 +701,7 @@ private:
   PrimExpr var_lanes_;
   // ramp representing the var.
   PrimExpr ramp_;
-  // flag to mark requirment of scalarization.
+  // flag to mark requirement of scalarization.
   bool need_scalarize_{false};
   // Let binding
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;

testing/python/language/test_tilelang_language_reshape.py

@@ -88,6 +88,7 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
     return main
 
+
 def run_reshape_smem_2d_2_1d(N, M, dtype):
     program = reshape_test_smem_2d_2_1d(N, M, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)

@@ -98,11 +99,11 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
 
 
 def test_reshape_smem_2d_2_1d():
     run_reshape_smem_2d_2_1d(1024, 32, "float32")
     run_reshape_smem_2d_2_1d(2048, 64, "float16")
 
 
 if __name__ == "__main__":
     tilelang.testing.main()

tilelang/autotuner/tuner.py

@@ -203,7 +203,7 @@ class AutoTuner:
             logger.warning(
                 "`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
             )
-            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731·
+            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731
         self.profile_args = ProfileArgs(
             supply_type=supply_type,
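
Only whitespace changes around the `# noqa: E731` comment here. E731 is the pycodestyle/ruff rule "do not assign a lambda expression, use a def", which the inline `# noqa` intentionally suppresses. A sketch of the two equivalent spellings, with `get_autotune_inputs` stubbed out instead of imported from tilelang:

```python
# Illustration only: a stub standing in for tilelang's get_autotune_inputs().
def get_autotune_inputs():
    return ["tensor_a", "tensor_b"]

# The form kept in tuner.py, with the lint warning suppressed inline.
supply_prog = lambda _: get_autotune_inputs()  # noqa: E731

# The def-based form that E731 would otherwise ask for.
def supply_prog_fn(_):
    return get_autotune_inputs()

assert supply_prog(None) == supply_prog_fn(None)
```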