OpenDAS / tilelang · Commits

Commit bcae814e (unverified)
Authored Dec 15, 2025 by Xiangwen Wang; committed by GitHub on Dec 15, 2025.

Enhance vectorized conversion support (#1438)
Parent: e387102c

Showing 4 changed files with 123 additions and 37 deletions:
src/target/codegen_cuda.cc (+63, −0)
src/tl_templates/cuda/cuda_fp8.h (+39, −32)
src/transform/layout_inference.cc (+9, −2)
testing/python/language/test_tilelang_language_vectorized_cast.py (+12, −3)
src/target/codegen_cuda.cc
...
@@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
       }
     }
+    if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) &&
+        target_ty.is_float()) {
+      // FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
+      // (fp8x2 -> float2)
+      if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+        // fp8x2 -> float2
+        PrintIndent();
+        stream << "*reinterpret_cast<float2*>(&(" << sret
+               << ")) = "
+                  "__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
+                  "t*>(&("
+               << src << ")), "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        os << sret;
+        return;
+      } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+        // fp8x4 -> float4
+        PrintIndent();
+        stream << "*(float2*)(&" << sret << ") = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[0], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        PrintIndent();
+        stream << "*((float2*)(&" << sret << ")+1) = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[1], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        os << sret;
+        return;
+      } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+        // fp8x8 -> float8
+        PrintIndent();
+        stream << "*(float2*)(&" << sret << ") = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[0], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        PrintIndent();
+        stream << "*((float2*)(&" << sret << ")+1) = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[1], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        PrintIndent();
+        stream << "*((float2*)(&" << sret << ")+2) = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[2], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        PrintIndent();
+        stream << "*((float2*)(&" << sret << ")+3) = "
+               << "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
+               << "))[3], "
+               << (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+               << ");\n";
+        os << sret;
+        return;
+      }
+    }
     // Fallback: elementwise cast
     for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
       std::ostringstream val;
...
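For reference, the 2-lane branch above emits CUDA of roughly this shape (a sketch assembled from the stream << statements in the hunk; dst and src stand in for the generated destination and operand expressions, and __NV_E5M2 is substituted for e5m2 sources):

  // Illustrative form of the emitted statement for an fp8x2 -> float2 cast.
  *reinterpret_cast<float2*>(&(dst)) = __tl_cvt_fp8x2_to_float2(
      *reinterpret_cast<__nv_fp8x2_storage_t*>(&(src)), __NV_E4M3);

The 4- and 8-lane branches emit the same conversion once per float2 pair, indexing the source as an array of __nv_fp8x2_storage_t.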
src/tl_templates/cuda/cuda_fp8.h
...
@@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
   fp8_e4_16_t x;
   fp8_e4_16_t y;
-  __device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
     x.x = *(fp8_e4_8_t *)&rhs.x;
     x.y = *(fp8_e4_8_t *)&rhs.y;
     y.x = *(fp8_e4_8_t *)&rhs.z;
...
@@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
   fp8_e5_16_t x;
   fp8_e5_16_t y;
-  __device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
     x.x = *(fp8_e5_8_t *)&rhs.x;
     x.y = *(fp8_e5_8_t *)&rhs.y;
     y.x = *(fp8_e5_8_t *)&rhs.z;
...
@@ -78,7 +78,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
 };
 // Pack two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
+TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
   fp8_e4_2_t result;
   result.x = x;
   result.y = y;
...
@@ -86,9 +86,8 @@ __forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
 }
 // Pack four fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                      fp8_e4_t x2, fp8_e4_t x3) {
+TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3) {
   fp8_e4_4_t result;
   result.x = x0;
   result.y = x1;
...
@@ -98,11 +97,9 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
 }
 // Pack eight fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                      fp8_e4_t x2, fp8_e4_t x3,
-                                                      fp8_e4_t x4, fp8_e4_t x5,
-                                                      fp8_e4_t x6, fp8_e4_t x7) {
+TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                     fp8_e4_t x6, fp8_e4_t x7) {
   fp8_e4_8_t result;
   result.x = make_fp8_e4_4_t(x0, x1, x2, x3);
   result.y = make_fp8_e4_4_t(x4, x5, x6, x7);
...
@@ -110,11 +107,12 @@ __forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
 }
 // Pack sixteen fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_16_t
-make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
-                 fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
-                 fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
-                 fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
+TL_DEVICE fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                       fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                       fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
+                                       fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                       fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
+                                       fp8_e4_t y7) {
   fp8_e4_16_t result;
   result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
   result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
...
@@ -122,7 +120,7 @@ make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
 }
 // Pack thirty-two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
+TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t(
     fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4,
     fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9,
     fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14,
...
@@ -139,7 +137,7 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
 }
 // Pack two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
+TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
   fp8_e5_2_t result;
   result.x = x;
   result.y = y;
...
@@ -147,9 +145,8 @@ __forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
 }
 // Pack four fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                      fp8_e5_t x2, fp8_e5_t x3) {
+TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3) {
   fp8_e5_4_t result;
   result.x = x0;
   result.y = x1;
...
@@ -159,11 +156,9 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
 }
 // Pack eight fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                      fp8_e5_t x2, fp8_e5_t x3,
-                                                      fp8_e5_t x4, fp8_e5_t x5,
-                                                      fp8_e5_t x6, fp8_e5_t x7) {
+TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                     fp8_e5_t x6, fp8_e5_t x7) {
   fp8_e5_8_t result;
   result.x = make_fp8_e5_4_t(x0, x1, x2, x3);
   result.y = make_fp8_e5_4_t(x4, x5, x6, x7);
...
@@ -171,11 +166,12 @@ __forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
 }
 // Pack sixteen fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_16_t
-make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
-                 fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
-                 fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
-                 fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) {
+TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                       fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                       fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t y0,
+                                       fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
+                                       fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6,
+                                       fp8_e5_t y7) {
   fp8_e5_16_t result;
   result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
   result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
...
@@ -183,7 +179,7 @@ make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
 }
 // Pack thirty-two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
+TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(
     fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4,
     fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9,
     fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14,
...
@@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
                               y12, y13, y14, y15);
   return result;
 }
+
+// e4m3x2 -> float2
+TL_DEVICE float2
+__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
+                         const __nv_fp8_interpretation_t fp8_interpretation) {
+  half2 tmp = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation);
+  float2 result;
+  result.x = (float)tmp.x;
+  result.y = (float)tmp.y;
+  return result;
+}
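A minimal sketch of exercising the new helper, assuming this header is on the include path (the kernel name demo_fp8x2_widen, the include paths, and the round trip through CUDA's __nv_cvt_float2_to_fp8x2 from <cuda_fp8.h> are illustrative, not part of the commit):

  #include <cuda_fp8.h>  // __nv_fp8x2_storage_t, __NV_E4M3, __NV_SATFINITE
  #include "tl_templates/cuda/cuda_fp8.h"  // __tl_cvt_fp8x2_to_float2 (path assumed)

  // Hypothetical demo: pack (1.0f, 2.0f) into two e4m3 bytes, then widen back.
  __global__ void demo_fp8x2_widen(float2 *out) {
    __nv_fp8x2_storage_t packed = __nv_cvt_float2_to_fp8x2(
        make_float2(1.0f, 2.0f), __NV_SATFINITE, __NV_E4M3);
    *out = __tl_cvt_fp8x2_to_float2(packed, __NV_E4M3);  // yields (1.0f, 2.0f)
  }

Both inputs are exactly representable in e4m3, so the round trip is exact; the helper widens fp8x2 -> half2 -> float2 in two steps, mirroring its definition above.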
src/transform/layout_inference.cc
...
@@ -20,6 +20,7 @@
 #include "../op/copy.h"
 #include "../op/parallel.h"
 #include "../op/region.h"
+#include "../target/utils.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
...
@@ -1170,9 +1171,15 @@ private:
     // If a cast operation exists, vectorization may still be required
     bool has_cast_operations = false;
     PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
-        // Check if this is a non-reducer store with Cast operation
-        if (store->value.as<CastNode>()) {
+      if (const auto *cast = obj.as<CastNode>()) {
+        DataType src_type = cast->value.dtype();
+        DataType dst_type = cast->dtype;
+        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
+                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
+        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
+                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
+        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
           has_cast_operations = true;
         }
       }
...
testing/python/language/test_tilelang_language_vectorized_cast.py
...
@@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
     kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
-    A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
-    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
-    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    A_float = torch.randn(M, dtype=torch.float32, device="cuda")
+    A = A_float.to(str2dtype[src_dtype_str])
+    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
     kernel(A, B)
     kernel_parallel(A, C)
...
@@ -101,6 +102,14 @@ def test_vectorized_cast():
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
     run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
+    # fp8_e4m3 -> fp32
+    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
+    run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)
+    # fp8_e5m2 -> fp32
+    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
+    run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)


 if __name__ == "__main__":
     tilelang.testing.main()