Unverified Commit bcae814e authored by Xiangwen Wang, committed by GitHub

Enhance vectorized conversion support (#1438)

parent e387102c
@@ -1139,6 +1139,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
}
if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) &&
target_ty.is_float()) {
// FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
// (fp8x2 -> float2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
// fp8x2 -> float2
PrintIndent();
stream << "*reinterpret_cast<float2*>(&(" << sret
<< ")) = "
"__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
"t*>(&("
<< src << ")), "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
// fp8x4 -> float4
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
// fp8x8 -> float32x8
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+2) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[2], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+3) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[3], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
}
}
// Fallback: elementwise cast
for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
std::ostringstream val;
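For reference, a sketch of the CUDA that the lanes == 4 branch above emits for an e4m3 source (sret and src stand in for the identifiers the codegen picks):

// Sketch of the emitted code: each packed fp8 pair widens into one
// float2 half of the four-lane result.
*(float2 *)(&sret) =
    __tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t *)(&src))[0], __NV_E4M3);
*((float2 *)(&sret) + 1) =
    __tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t *)(&src))[1], __NV_E4M3);

The lanes == 8 branch follows the same pattern with four float2 stores.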
@@ -33,7 +33,7 @@ struct __CUDA_ALIGN__(32) fp8_e4_32_t {
fp8_e4_16_t x;
fp8_e4_16_t y;
-  __device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e4_8_t *)&rhs.x;
x.y = *(fp8_e4_8_t *)&rhs.y;
y.x = *(fp8_e4_8_t *)&rhs.z;
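Every edit in this header swaps the explicit qualifier pair for TL_DEVICE. The macro itself is not shown in this diff, but the one-for-one replacement suggests a definition along these lines:

// Assumed definition (not part of this diff): TL_DEVICE abbreviates the
// device-side force-inline qualifiers it replaces throughout this file.
#define TL_DEVICE __forceinline__ __device__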
@@ -68,7 +68,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
fp8_e5_16_t x;
fp8_e5_16_t y;
-  __device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
+  TL_DEVICE fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e5_8_t *)&rhs.x;
x.y = *(fp8_e5_8_t *)&rhs.y;
y.x = *(fp8_e5_8_t *)&rhs.z;
@@ -78,7 +78,7 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
};
// Pack two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
+TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
fp8_e4_2_t result;
result.x = x;
result.y = y;
@@ -86,9 +86,8 @@ __forceinline__ __device__ fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
}
// Pack four fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                      fp8_e4_t x2,
-                                                      fp8_e4_t x3) {
+TL_DEVICE fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3) {
fp8_e4_4_t result;
result.x = x0;
result.y = x1;
@@ -98,11 +97,9 @@ __forceinline__ __device__ fp8_e4_4_t make_fp8_e4_4_t(fp8_e4_t x0, fp8_e4_t x1,
}
// Pack eight fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
-                                                      fp8_e4_t x2, fp8_e4_t x3,
-                                                      fp8_e4_t x4, fp8_e4_t x5,
-                                                      fp8_e4_t x6,
-                                                      fp8_e4_t x7) {
+TL_DEVICE fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                     fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                     fp8_e4_t x6, fp8_e4_t x7) {
fp8_e4_8_t result;
result.x = make_fp8_e4_4_t(x0, x1, x2, x3);
result.y = make_fp8_e4_4_t(x4, x5, x6, x7);
@@ -110,11 +107,12 @@ __forceinline__ __device__ fp8_e4_8_t make_fp8_e4_8_t(fp8_e4_t x0, fp8_e4_t x1,
}
// Pack sixteen fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_16_t
-make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
-                 fp8_e4_t x4, fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7,
-                 fp8_e4_t y0, fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
-                 fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6, fp8_e4_t y7) {
+TL_DEVICE fp8_e4_16_t make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2,
+                                       fp8_e4_t x3, fp8_e4_t x4, fp8_e4_t x5,
+                                       fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t y0,
+                                       fp8_e4_t y1, fp8_e4_t y2, fp8_e4_t y3,
+                                       fp8_e4_t y4, fp8_e4_t y5, fp8_e4_t y6,
+                                       fp8_e4_t y7) {
fp8_e4_16_t result;
result.x = make_fp8_e4_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e4_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
@@ -122,7 +120,7 @@ make_fp8_e4_16_t(fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3,
}
// Pack thirty-two fp8_e4_t values.
-__forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
+TL_DEVICE fp8_e4_32_t make_fp8_e4_32_t(
fp8_e4_t x0, fp8_e4_t x1, fp8_e4_t x2, fp8_e4_t x3, fp8_e4_t x4,
fp8_e4_t x5, fp8_e4_t x6, fp8_e4_t x7, fp8_e4_t x8, fp8_e4_t x9,
fp8_e4_t x10, fp8_e4_t x11, fp8_e4_t x12, fp8_e4_t x13, fp8_e4_t x14,
@@ -139,7 +137,7 @@ __forceinline__ __device__ fp8_e4_32_t make_fp8_e4_32_t(
}
// Pack two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
+TL_DEVICE fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
fp8_e5_2_t result;
result.x = x;
result.y = y;
@@ -147,9 +145,8 @@ __forceinline__ __device__ fp8_e5_2_t make_fp8_e5_2_t(fp8_e5_t x, fp8_e5_t y) {
}
// Pack four fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                      fp8_e5_t x2,
-                                                      fp8_e5_t x3) {
+TL_DEVICE fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3) {
fp8_e5_4_t result;
result.x = x0;
result.y = x1;
@@ -159,11 +156,9 @@ __forceinline__ __device__ fp8_e5_4_t make_fp8_e5_4_t(fp8_e5_t x0, fp8_e5_t x1,
}
// Pack eight fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
-                                                      fp8_e5_t x2, fp8_e5_t x3,
-                                                      fp8_e5_t x4, fp8_e5_t x5,
-                                                      fp8_e5_t x6,
-                                                      fp8_e5_t x7) {
+TL_DEVICE fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                     fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                     fp8_e5_t x6, fp8_e5_t x7) {
fp8_e5_8_t result;
result.x = make_fp8_e5_4_t(x0, x1, x2, x3);
result.y = make_fp8_e5_4_t(x4, x5, x6, x7);
@@ -171,11 +166,12 @@ __forceinline__ __device__ fp8_e5_8_t make_fp8_e5_8_t(fp8_e5_t x0, fp8_e5_t x1,
}
// Pack sixteen fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_16_t
-make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
-                 fp8_e5_t x4, fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7,
-                 fp8_e5_t y0, fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
-                 fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6, fp8_e5_t y7) {
+TL_DEVICE fp8_e5_16_t make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2,
+                                       fp8_e5_t x3, fp8_e5_t x4, fp8_e5_t x5,
+                                       fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t y0,
+                                       fp8_e5_t y1, fp8_e5_t y2, fp8_e5_t y3,
+                                       fp8_e5_t y4, fp8_e5_t y5, fp8_e5_t y6,
+                                       fp8_e5_t y7) {
fp8_e5_16_t result;
result.x = make_fp8_e5_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e5_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
@@ -183,7 +179,7 @@ make_fp8_e5_16_t(fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3,
}
// Pack thirty-two fp8_e5_t values.
-__forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
+TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(
fp8_e5_t x0, fp8_e5_t x1, fp8_e5_t x2, fp8_e5_t x3, fp8_e5_t x4,
fp8_e5_t x5, fp8_e5_t x6, fp8_e5_t x7, fp8_e5_t x8, fp8_e5_t x9,
fp8_e5_t x10, fp8_e5_t x11, fp8_e5_t x12, fp8_e5_t x13, fp8_e5_t x14,
@@ -198,3 +194,14 @@ __forceinline__ __device__ fp8_e5_32_t make_fp8_e5_32_t(
y12, y13, y14, y15);
return result;
}
// fp8x2 (e4m3 or e5m2, per the interpretation argument) -> float2
TL_DEVICE float2
__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
const __nv_fp8_interpretation_t fp8_interpretation) {
half2 tmp = __nv_cvt_fp8x2_to_halfraw2(x, fp8_interpretation);
float2 result;
result.x = (float)tmp.x;
result.y = (float)tmp.y;
return result;
}
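A minimal usage sketch for the new helper (hypothetical kernel; assumes cuda_fp8.h is included and the input buffer holds packed e4m3 pairs):

#include <cuda_fp8.h>

// Hypothetical kernel: widen n packed e4m3 pairs into float2 values.
__global__ void widen_fp8x2(const __nv_fp8x2_storage_t *in, float2 *out,
                            int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = __tl_cvt_fp8x2_to_float2(in[i], __NV_E4M3);
}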
@@ -20,6 +20,7 @@
#include "../op/copy.h"
#include "../op/parallel.h"
#include "../op/region.h"
#include "../target/utils.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
@@ -1170,9 +1171,15 @@ private:
// If a cast operation exists, vectorization may still be required
bool has_cast_operations = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
+      if (const auto *cast = obj.as<CastNode>()) {
        // Check if this is a non-reducer store with Cast operation
-        if (store->value.as<CastNode>()) {
+        DataType src_type = cast->value.dtype();
+        DataType dst_type = cast->dtype;
+        bool src_ok = src_type.is_float() || src_type.is_bfloat() ||
+                      src_type.is_float8_e4m3() || src_type.is_float8_e5m2();
+        bool dst_ok = dst_type.is_float() || dst_type.is_bfloat() ||
+                      dst_type.is_float8_e4m3() || dst_type.is_float8_e5m2();
+        if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
has_cast_operations = true;
}
}
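The eligibility test amounts to the same dtype predicate applied to both sides of the cast; a standalone restatement (hypothetical helper name, not in the diff):

// Hypothetical restatement of the check above: a cast is considered for
// vectorization on CUDA only when both dtypes are float, bfloat, or fp8.
static bool IsVectorizableCastDtype(DataType t) {
  return t.is_float() || t.is_bfloat() || t.is_float8_e4m3() ||
         t.is_float8_e5m2();
}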
@@ -60,9 +60,10 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
-    A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
-    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
-    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    A_float = torch.randn(M, dtype=torch.float32, device="cuda")
+    A = A_float.to(str2dtype[src_dtype_str])
+    B = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str], device="cuda")
kernel(A, B)
kernel_parallel(A, C)
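The float32 intermediate is presumably needed because torch.randn cannot sample fp8 tensors directly; values are drawn in float32 and then narrowed to the source dtype.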
@@ -101,6 +102,14 @@ def test_vectorized_cast():
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
# fp8_e4m3 -> fp32
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)
# fp8_e5m2 -> fp32
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)
if __name__ == "__main__":
tilelang.testing.main()