Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
bcae814e
"docs/source/en/using-diffusers/overview_techniques.md" did not exist on "595ba6f786510aaea77219ba4a2647255a41ede1"
Unverified
Commit
bcae814e
authored
Dec 15, 2025
by
Xiangwen Wang
Committed by
GitHub
Dec 15, 2025
Browse files
Enhance vectorized conversion support (#1438)
parent
e387102c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
123 additions
and
37 deletions
+123
-37
src/target/codegen_cuda.cc
src/target/codegen_cuda.cc
+63
-0
src/tl_templates/cuda/cuda_fp8.h
src/tl_templates/cuda/cuda_fp8.h
+39
-32
src/transform/layout_inference.cc
src/transform/layout_inference.cc
+9
-2
testing/python/language/test_tilelang_language_vectorized_cast.py
...python/language/test_tilelang_language_vectorized_cast.py
+12
-3
No files found.
src/target/codegen_cuda.cc
View file @
bcae814e
...
@@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
...
@@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}
}
}
if
((
from_ty
.
is_float8_e4m3
()
||
from_ty
.
is_float8_e5m2
())
&&
target_ty
.
is_float
())
{
// FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
// (fp8x2 -> float2)
if
(
from_ty
.
lanes
()
==
2
&&
target_ty
.
lanes
()
==
2
)
{
// fp8x2 -> float2
PrintIndent
();
stream
<<
"*reinterpret_cast<float2*>(&("
<<
sret
<<
")) = "
"__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
"t*>(&("
<<
src
<<
")), "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
os
<<
sret
;
return
;
}
else
if
(
from_ty
.
lanes
()
==
4
&&
target_ty
.
lanes
()
==
4
)
{
// fp8x4 -> float4
PrintIndent
();
stream
<<
"*(float2*)(&"
<<
sret
<<
") = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[0], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
PrintIndent
();
stream
<<
"*((float2*)(&"
<<
sret
<<
")+1) = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[1], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
os
<<
sret
;
return
;
}
else
if
(
from_ty
.
lanes
()
==
8
&&
target_ty
.
lanes
()
==
8
)
{
// fp8x8 -> float8
PrintIndent
();
stream
<<
"*(float2*)(&"
<<
sret
<<
") = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[0], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
PrintIndent
();
stream
<<
"*((float2*)(&"
<<
sret
<<
")+1) = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[1], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
PrintIndent
();
stream
<<
"*((float2*)(&"
<<
sret
<<
")+2) = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[2], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
PrintIndent
();
stream
<<
"*((float2*)(&"
<<
sret
<<
")+3) = "
<<
"__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&"
<<
src
<<
"))[3], "
<<
(
from_ty
.
is_float8_e4m3
()
?
"__NV_E4M3"
:
"__NV_E5M2"
)
<<
");
\n
"
;
os
<<
sret
;
return
;
}
}
// Fallback: elementwise cast
// Fallback: elementwise cast
for
(
int
i
=
0
,
lanes
=
from_ty
.
lanes
();
i
<
lanes
;
++
i
)
{
for
(
int
i
=
0
,
lanes
=
from_ty
.
lanes
();
i
<
lanes
;
++
i
)
{
std
::
ostringstream
val
;
std
::
ostringstream
val
;
...
...
src/tl_templates/cuda/cuda_fp8.h
View file @
bcae814e
...
@@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
...
@@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
fp8_e4_16_t
x
;
fp8_e4_16_t
x
;
fp8_e4_16_t
y
;
fp8_e4_16_t
y
;
__device__
__forceinline__
fp8_e4_32_t
&
operator
=
(
const
ulonglong4
&
rhs
)
{
TL_DEVICE
fp8_e4_32_t
&
operator
=
(
const
ulonglong4
&
rhs
)
{
x
.
x
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
x
;
x
.
x
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
x
;
x
.
y
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
y
;
x
.
y
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
y
;
y
.
x
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
z
;
y
.
x
=
*
(
fp8_e4_8_t
*
)
&
rhs
.
z
;
...
@@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
...
@@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
fp8_e5_16_t
x
;
fp8_e5_16_t
x
;
fp8_e5_16_t
y
;
fp8_e5_16_t
y
;
__device__
__forceinline__
fp8_e5_32_t
&
operator
=
(
const
ulonglong4
&
rhs
)
{
TL_DEVICE
fp8_e5_32_t
&
operator
=
(
const
ulonglong4
&
rhs
)
{
x
.
x
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
x
;
x
.
x
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
x
;
x
.
y
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
y
;
x
.
y
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
y
;
y
.
x
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
z
;
y
.
x
=
*
(
fp8_e5_8_t
*
)
&
rhs
.
z
;
...
@@ -78,7 +78,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
...
@@ -78,7 +78,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
};
};
// Pack two fp8_e4_t values.
// Pack two fp8_e4_t values.
__forceinline__
__device__
fp8_e4_2_t
make_fp8_e4_2_t
(
fp8_e4_t
x
,
fp8_e4_t
y
)
{
TL_DEVICE
fp8_e4_2_t
make_fp8_e4_2_t
(
fp8_e4_t
x
,
fp8_e4_t
y
)
{
fp8_e4_2_t
result
;
fp8_e4_2_t
result
;
result
.
x
=
x
;
result
.
x
=
x
;
result
.
y
=
y
;
result
.
y
=
y
;
...
@@ -86,9 +86,8 @@ __forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
...
@@ -86,9 +86,8 @@ __forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
}
}
// Pack four fp8_e4_t values.
// Pack four fp8_e4_t values.
__forceinline__
__device__
fp8_e4_4_t
make_fp8_e4_4_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
TL_DEVICE
fp8_e4_4_t
make_fp8_e4_4_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
fp8_e4_t
x2
,
fp8_e4_t
x3
)
{
fp8_e4_t
x3
)
{
fp8_e4_4_t
result
;
fp8_e4_4_t
result
;
result
.
x
=
x0
;
result
.
x
=
x0
;
result
.
y
=
x1
;
result
.
y
=
x1
;
...
@@ -98,11 +97,9 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
...
@@ -98,11 +97,9 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
}
}
// Pack eight fp8_e4_t values.
// Pack eight fp8_e4_t values.
__forceinline__
__device__
fp8_e4_8_t
make_fp8_e4_8_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
TL_DEVICE
fp8_e4_8_t
make_fp8_e4_8_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
fp8_e4_t
x2
,
fp8_e4_t
x3
,
fp8_e4_t
x3
,
fp8_e4_t
x4
,
fp8_e4_t
x5
,
fp8_e4_t
x4
,
fp8_e4_t
x5
,
fp8_e4_t
x6
,
fp8_e4_t
x7
)
{
fp8_e4_t
x6
,
fp8_e4_t
x7
)
{
fp8_e4_8_t
result
;
fp8_e4_8_t
result
;
result
.
x
=
make_fp8_e4_4_t
(
x0
,
x1
,
x2
,
x3
);
result
.
x
=
make_fp8_e4_4_t
(
x0
,
x1
,
x2
,
x3
);
result
.
y
=
make_fp8_e4_4_t
(
x4
,
x5
,
x6
,
x7
);
result
.
y
=
make_fp8_e4_4_t
(
x4
,
x5
,
x6
,
x7
);
...
@@ -110,11 +107,12 @@ __forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
...
@@ -110,11 +107,12 @@ __forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
}
}
// Pack sixteen fp8_e4_t values.
// Pack sixteen fp8_e4_t values.
__forceinline__
__device__
fp8_e4_16_t
TL_DEVICE
fp8_e4_16_t
make_fp8_e4_16_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
make_fp8_e4_16_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
fp8_e4_t
x3
,
fp8_e4_t
x3
,
fp8_e4_t
x4
,
fp8_e4_t
x5
,
fp8_e4_t
x4
,
fp8_e4_t
x5
,
fp8_e4_t
x6
,
fp8_e4_t
x7
,
fp8_e4_t
x6
,
fp8_e4_t
x7
,
fp8_e4_t
y0
,
fp8_e4_t
y0
,
fp8_e4_t
y1
,
fp8_e4_t
y2
,
fp8_e4_t
y3
,
fp8_e4_t
y1
,
fp8_e4_t
y2
,
fp8_e4_t
y3
,
fp8_e4_t
y4
,
fp8_e4_t
y5
,
fp8_e4_t
y6
,
fp8_e4_t
y7
)
{
fp8_e4_t
y4
,
fp8_e4_t
y5
,
fp8_e4_t
y6
,
fp8_e4_t
y7
)
{
fp8_e4_16_t
result
;
fp8_e4_16_t
result
;
result
.
x
=
make_fp8_e4_8_t
(
x0
,
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
);
result
.
x
=
make_fp8_e4_8_t
(
x0
,
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
);
result
.
y
=
make_fp8_e4_8_t
(
y0
,
y1
,
y2
,
y3
,
y4
,
y5
,
y6
,
y7
);
result
.
y
=
make_fp8_e4_8_t
(
y0
,
y1
,
y2
,
y3
,
y4
,
y5
,
y6
,
y7
);
...
@@ -122,7 +120,7 @@ make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
...
@@ -122,7 +120,7 @@ make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
}
}
// Pack thirty-two fp8_e4_t values.
// Pack thirty-two fp8_e4_t values.
__forceinline__
__device__
fp8_e4_32_t
make_fp8_e4_32_t
(
TL_DEVICE
fp8_e4_32_t
make_fp8_e4_32_t
(
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
fp8_e4_t
x3
,
fp8_e4_t
x4
,
fp8_e4_t
x0
,
fp8_e4_t
x1
,
fp8_e4_t
x2
,
fp8_e4_t
x3
,
fp8_e4_t
x4
,
fp8_e4_t
x5
,
fp8_e4_t
x6
,
fp8_e4_t
x7
,
fp8_e4_t
x8
,
fp8_e4_t
x9
,
fp8_e4_t
x5
,
fp8_e4_t
x6
,
fp8_e4_t
x7
,
fp8_e4_t
x8
,
fp8_e4_t
x9
,
fp8_e4_t
x10
,
fp8_e4_t
x11
,
fp8_e4_t
x12
,
fp8_e4_t
x13
,
fp8_e4_t
x14
,
fp8_e4_t
x10
,
fp8_e4_t
x11
,
fp8_e4_t
x12
,
fp8_e4_t
x13
,
fp8_e4_t
x14
,
...
@@ -139,7 +137,7 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
...
@@ -139,7 +137,7 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
}
}
// Pack two fp8_e5_t values.
// Pack two fp8_e5_t values.
__forceinline__
__device__
fp8_e5_2_t
make_fp8_e5_2_t
(
fp8_e5_t
x
,
fp8_e5_t
y
)
{
TL_DEVICE
fp8_e5_2_t
make_fp8_e5_2_t
(
fp8_e5_t
x
,
fp8_e5_t
y
)
{
fp8_e5_2_t
result
;
fp8_e5_2_t
result
;
result
.
x
=
x
;
result
.
x
=
x
;
result
.
y
=
y
;
result
.
y
=
y
;
...
@@ -147,9 +145,8 @@ __forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
...
@@ -147,9 +145,8 @@ __forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
}
}
// Pack four fp8_e5_t values.
// Pack four fp8_e5_t values.
__forceinline__
__device__
fp8_e5_4_t
make_fp8_e5_4_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
TL_DEVICE
fp8_e5_4_t
make_fp8_e5_4_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
fp8_e5_t
x2
,
fp8_e5_t
x3
)
{
fp8_e5_t
x3
)
{
fp8_e5_4_t
result
;
fp8_e5_4_t
result
;
result
.
x
=
x0
;
result
.
x
=
x0
;
result
.
y
=
x1
;
result
.
y
=
x1
;
...
@@ -159,11 +156,9 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
...
@@ -159,11 +156,9 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
}
}
// Pack eight fp8_e5_t values.
// Pack eight fp8_e5_t values.
__forceinline__
__device__
fp8_e5_8_t
make_fp8_e5_8_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
TL_DEVICE
fp8_e5_8_t
make_fp8_e5_8_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
fp8_e5_t
x2
,
fp8_e5_t
x3
,
fp8_e5_t
x3
,
fp8_e5_t
x4
,
fp8_e5_t
x5
,
fp8_e5_t
x4
,
fp8_e5_t
x5
,
fp8_e5_t
x6
,
fp8_e5_t
x7
)
{
fp8_e5_t
x6
,
fp8_e5_t
x7
)
{
fp8_e5_8_t
result
;
fp8_e5_8_t
result
;
result
.
x
=
make_fp8_e5_4_t
(
x0
,
x1
,
x2
,
x3
);
result
.
x
=
make_fp8_e5_4_t
(
x0
,
x1
,
x2
,
x3
);
result
.
y
=
make_fp8_e5_4_t
(
x4
,
x5
,
x6
,
x7
);
result
.
y
=
make_fp8_e5_4_t
(
x4
,
x5
,
x6
,
x7
);
...
@@ -171,11 +166,12 @@ __forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
...
@@ -171,11 +166,12 @@ __forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
}
}
// Pack sixteen fp8_e5_t values.
// Pack sixteen fp8_e5_t values.
__forceinline__
__device__
fp8_e5_16_t
TL_DEVICE
fp8_e5_16_t
make_fp8_e5_16_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
make_fp8_e5_16_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
fp8_e5_t
x3
,
fp8_e5_t
x3
,
fp8_e5_t
x4
,
fp8_e5_t
x5
,
fp8_e5_t
x4
,
fp8_e5_t
x5
,
fp8_e5_t
x6
,
fp8_e5_t
x7
,
fp8_e5_t
x6
,
fp8_e5_t
x7
,
fp8_e5_t
y0
,
fp8_e5_t
y0
,
fp8_e5_t
y1
,
fp8_e5_t
y2
,
fp8_e5_t
y3
,
fp8_e5_t
y1
,
fp8_e5_t
y2
,
fp8_e5_t
y3
,
fp8_e5_t
y4
,
fp8_e5_t
y5
,
fp8_e5_t
y6
,
fp8_e5_t
y7
)
{
fp8_e5_t
y4
,
fp8_e5_t
y5
,
fp8_e5_t
y6
,
fp8_e5_t
y7
)
{
fp8_e5_16_t
result
;
fp8_e5_16_t
result
;
result
.
x
=
make_fp8_e5_8_t
(
x0
,
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
);
result
.
x
=
make_fp8_e5_8_t
(
x0
,
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
);
result
.
y
=
make_fp8_e5_8_t
(
y0
,
y1
,
y2
,
y3
,
y4
,
y5
,
y6
,
y7
);
result
.
y
=
make_fp8_e5_8_t
(
y0
,
y1
,
y2
,
y3
,
y4
,
y5
,
y6
,
y7
);
...
@@ -183,7 +179,7 @@ make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
...
@@ -183,7 +179,7 @@ make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
}
}
// Pack thirty-two fp8_e5_t values.
// Pack thirty-two fp8_e5_t values.
__forceinline__
__device__
fp8_e5_32_t
make_fp8_e5_32_t
(
TL_DEVICE
fp8_e5_32_t
make_fp8_e5_32_t
(
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
fp8_e5_t
x3
,
fp8_e5_t
x4
,
fp8_e5_t
x0
,
fp8_e5_t
x1
,
fp8_e5_t
x2
,
fp8_e5_t
x3
,
fp8_e5_t
x4
,
fp8_e5_t
x5
,
fp8_e5_t
x6
,
fp8_e5_t
x7
,
fp8_e5_t
x8
,
fp8_e5_t
x9
,
fp8_e5_t
x5
,
fp8_e5_t
x6
,
fp8_e5_t
x7
,
fp8_e5_t
x8
,
fp8_e5_t
x9
,
fp8_e5_t
x10
,
fp8_e5_t
x11
,
fp8_e5_t
x12
,
fp8_e5_t
x13
,
fp8_e5_t
x14
,
fp8_e5_t
x10
,
fp8_e5_t
x11
,
fp8_e5_t
x12
,
fp8_e5_t
x13
,
fp8_e5_t
x14
,
...
@@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
...
@@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
y12
,
y13
,
y14
,
y15
);
y12
,
y13
,
y14
,
y15
);
return
result
;
return
result
;
}
}
// fp8x2 -> float2
// Converts a packed pair of FP8 values into a float2. The FP8 format of the
// pair is not fixed here: it is chosen by `fp8_interpretation`, which is
// forwarded unchanged to __nv_cvt_fp8x2_to_halfraw2. The pair is first widened
// to half precision, then each half lane is cast to float
// (lane .x -> result.x, lane .y -> result.y).
TL_DEVICE float2 __tl_cvt_fp8x2_to_float2(
    const __nv_fp8x2_storage_t x,
    const __nv_fp8_interpretation_t fp8_interpretation) {
  // Step 1: fp8x2 -> half2 through the raw-half conversion intrinsic.
  const half2 halves = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation);
  // Step 2: widen each half lane to float and pack the pair.
  return float2{static_cast<float>(halves.x), static_cast<float>(halves.y)};
}
src/transform/layout_inference.cc
View file @
bcae814e
...
@@ -20,6 +20,7 @@
...
@@ -20,6 +20,7 @@
#include "../op/copy.h"
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/parallel.h"
#include "../op/region.h"
#include "../op/region.h"
#include "../target/utils.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
...
@@ -1170,9 +1171,15 @@ private:
...
@@ -1170,9 +1171,15 @@ private:
// If a cast operation exists, vectorization may still be required
// If a cast operation exists, vectorization may still be required
bool
has_cast_operations
=
false
;
bool
has_cast_operations
=
false
;
PostOrderVisit
(
for_node
->
body
,
[
&
](
const
ObjectRef
&
obj
)
{
PostOrderVisit
(
for_node
->
body
,
[
&
](
const
ObjectRef
&
obj
)
{
if
(
const
auto
*
st
ore
=
obj
.
as
<
BufferStore
Node
>
())
{
if
(
const
auto
*
ca
st
=
obj
.
as
<
Cast
Node
>
())
{
// Check if this is a non-reducer store with Cast operation
// Check if this is a non-reducer store with Cast operation
if
(
store
->
value
.
as
<
CastNode
>
())
{
DataType
src_type
=
cast
->
value
.
dtype
();
DataType
dst_type
=
cast
->
dtype
;
bool
src_ok
=
src_type
.
is_float
()
||
src_type
.
is_bfloat
()
||
src_type
.
is_float8_e4m3
()
||
src_type
.
is_float8_e5m2
();
bool
dst_ok
=
dst_type
.
is_float
()
||
dst_type
.
is_bfloat
()
||
dst_type
.
is_float8_e4m3
()
||
dst_type
.
is_float8_e5m2
();
if
(
src_ok
&&
dst_ok
&&
TargetIsCuda
(
Target
::
Current
()))
{
has_cast_operations
=
true
;
has_cast_operations
=
true
;
}
}
}
}
...
...
testing/python/language/test_tilelang_language_vectorized_cast.py
View file @
bcae814e
...
@@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
...
@@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
kernel
=
vectorized_cast_kernel
(
M
,
src_dtype_str
,
dst_dtype_str
)
kernel
=
vectorized_cast_kernel
(
M
,
src_dtype_str
,
dst_dtype_str
)
kernel_parallel
=
parallel_vectorized_cast_kernel
(
M
,
src_dtype_str
,
dst_dtype_str
)
kernel_parallel
=
parallel_vectorized_cast_kernel
(
M
,
src_dtype_str
,
dst_dtype_str
)
A
=
torch
.
randn
(
M
,
dtype
=
str2dtype
[
src_dtype_str
]).
cuda
()
A_float
=
torch
.
randn
(
M
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
B
=
torch
.
zeros
(
M
,
dtype
=
str2dtype
[
dst_dtype_str
]).
cuda
()
A
=
A_float
.
to
(
str2dtype
[
src_dtype_str
])
C
=
torch
.
zeros
(
M
,
dtype
=
str2dtype
[
dst_dtype_str
]).
cuda
()
B
=
torch
.
zeros
(
M
,
dtype
=
str2dtype
[
dst_dtype_str
],
device
=
"cuda"
)
C
=
torch
.
zeros
(
M
,
dtype
=
str2dtype
[
dst_dtype_str
],
device
=
"cuda"
)
kernel
(
A
,
B
)
kernel
(
A
,
B
)
kernel_parallel
(
A
,
C
)
kernel_parallel
(
A
,
C
)
...
@@ -101,6 +102,14 @@ def test_vectorized_cast():
...
@@ -101,6 +102,14 @@ def test_vectorized_cast():
run_vectorized_cast
(
"bfloat16"
,
"float32"
,
"__bfloat1622float2"
,
2
)
run_vectorized_cast
(
"bfloat16"
,
"float32"
,
"__bfloat1622float2"
,
2
)
run_vectorized_cast
(
"bfloat16"
,
"float32"
,
"__bfloat1622float2"
,
4
)
run_vectorized_cast
(
"bfloat16"
,
"float32"
,
"__bfloat1622float2"
,
4
)
# fp8_e4m3 -> fp32
run_vectorized_cast
(
"float8_e4m3"
,
"float32"
,
"__tl_cvt_fp8x2_to_float2"
,
2
)
run_vectorized_cast
(
"float8_e4m3"
,
"float32"
,
"__tl_cvt_fp8x2_to_float2"
,
4
)
# fp8_e5m2 -> fp32
run_vectorized_cast
(
"float8_e5m2"
,
"float32"
,
"__tl_cvt_fp8x2_to_float2"
,
2
)
run_vectorized_cast
(
"float8_e5m2"
,
"float32"
,
"__tl_cvt_fp8x2_to_float2"
,
4
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
tilelang
.
testing
.
main
()
tilelang
.
testing
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment