Unverified Commit feef9ef6 authored by LJC00118, committed by GitHub

[Enhancement] Enhance Cast operations Vectorization (#1156)

* Enhance cast vectorization

* Add Parallel vectorized cast test

* code lint

* merge newest commit
parent 198f22b3
@@ -919,6 +919,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // half8 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   } else if (from_ty.is_float() && target_ty.is_float16()) {
     // Use __float22half2_rn for vectorized conversion (float2 -> half2)
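For reference, the four statements emitted by the new 8-lane branch above expand to packed `__half22float2` conversions. Below is a hedged sketch of what the generated CUDA could look like for a half8 -> float8 cast; the kernel wrapper and the names `src`/`sret` are illustrative stand-ins for the temporaries the codegen actually declares, not its verbatim output.

```cuda
#include <cuda_fp16.h>

// Hedged sketch of the code the branch above generates for a half8 -> float8
// cast. `src` and `sret` stand in for codegen temporaries; the kernel wrapper
// is an assumption added so the snippet is self-contained.
__global__ void half8_to_float8_sketch(const half* __restrict__ in,
                                       float* __restrict__ out) {
  __align__(16) half src[8];
  __align__(32) float sret[8];
#pragma unroll
  for (int i = 0; i < 8; ++i) src[i] = in[threadIdx.x * 8 + i];

  // Eight scalar half -> float casts become four packed half2 -> float2 ones.
  ((float2*)(&sret))[0] = __half22float2(*(half2*)(&(src)));
  ((float2*)(&sret))[1] = __half22float2(*((half2*)(&(src)) + 1));
  ((float2*)(&sret))[2] = __half22float2(*((half2*)(&(src)) + 2));
  ((float2*)(&sret))[3] = __half22float2(*((half2*)(&(src)) + 3));

#pragma unroll
  for (int i = 0; i < 8; ++i) out[threadIdx.x * 8 + i] = sret[i];
}
```

Each `__half22float2` call converts two lanes at once, so the eight-lane cast collapses into four packed conversions.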
@@ -939,6 +955,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> half8
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[2] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[3] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   }
@@ -965,6 +997,26 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // bfloat162x4 -> float8
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[2] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+2));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[3] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+3));\n";
+      os << sret;
+      return;
     }
   } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
     // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
@@ -985,6 +1037,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> bfloat162x4
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[2] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+2));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[3] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+3));\n";
+      os << sret;
+      return;
     }
   }
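The bfloat16 path mirrors the half path but goes through `reinterpret_cast<__nv_bfloat162*>` because the packed bfloat16 type has no C-style vector alias. A hedged sketch of the emitted float8 -> bfloat16x8 conversion follows; the variable names and the wrapper kernel are assumptions, not the codegen's output.

```cuda
#include <cuda_bf16.h>

// Hedged sketch of the generated float8 -> bfloat16x8 conversion. Variable
// names and the wrapper kernel are assumptions, not the codegen's output.
__global__ void float8_to_bf16x8_sketch(const float* __restrict__ in,
                                        __nv_bfloat16* __restrict__ out) {
  __align__(32) float src[8];
  __align__(16) __nv_bfloat16 sret[8];
#pragma unroll
  for (int i = 0; i < 8; ++i) src[i] = in[threadIdx.x * 8 + i];

  // Each __float22bfloat162_rn packs two float lanes into one __nv_bfloat162.
  (reinterpret_cast<__nv_bfloat162*>(&sret))[0] =
      __float22bfloat162_rn(*(float2*)(&(src)));
  (reinterpret_cast<__nv_bfloat162*>(&sret))[1] =
      __float22bfloat162_rn(*((float2*)(&(src)) + 1));
  (reinterpret_cast<__nv_bfloat162*>(&sret))[2] =
      __float22bfloat162_rn(*((float2*)(&(src)) + 2));
  (reinterpret_cast<__nv_bfloat162*>(&sret))[3] =
      __float22bfloat162_rn(*((float2*)(&(src)) + 3));

#pragma unroll
  for (int i = 0; i < 8; ++i) out[threadIdx.x * 8 + i] = sret[i];
}
```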
@@ -1019,6 +1087,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
              << ");\n";
       os << sret;
       return;
+    } else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
+      // float8 -> fp8x8
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+2), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+3), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
     }
   }
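For fp8 targets, pairs of float lanes are packed with `__nv_cvt_float2_to_fp8x2`, selecting `__NV_E4M3` or `__NV_E5M2` from the target dtype. A hedged sketch of the emitted float8 -> fp8x8 conversion for an e4m3 target (the names and the raw `__nv_fp8_storage_t` output buffer are assumptions):

```cuda
#include <cuda_fp8.h>

// Hedged sketch of the generated float8 -> fp8x8 conversion for an e4m3
// target (e5m2 only swaps the interpretation enum). Names and the output
// buffer of raw __nv_fp8_storage_t bytes are assumptions.
__global__ void float8_to_fp8x8_sketch(const float* __restrict__ in,
                                       __nv_fp8_storage_t* __restrict__ out) {
  __align__(32) float src[8];
  __align__(8) __nv_fp8_storage_t sret[8];
#pragma unroll
  for (int i = 0; i < 8; ++i) src[i] = in[threadIdx.x * 8 + i];

  // Each call converts two float lanes into a packed fp8x2 value with
  // saturating conversion.
  ((__nv_fp8x2_storage_t*)(&sret))[0] =
      __nv_cvt_float2_to_fp8x2(*(float2*)(&(src)), __NV_SATFINITE, __NV_E4M3);
  ((__nv_fp8x2_storage_t*)(&sret))[1] =
      __nv_cvt_float2_to_fp8x2(*((float2*)(&(src)) + 1), __NV_SATFINITE, __NV_E4M3);
  ((__nv_fp8x2_storage_t*)(&sret))[2] =
      __nv_cvt_float2_to_fp8x2(*((float2*)(&(src)) + 2), __NV_SATFINITE, __NV_E4M3);
  ((__nv_fp8x2_storage_t*)(&sret))[3] =
      __nv_cvt_float2_to_fp8x2(*((float2*)(&(src)) + 3), __NV_SATFINITE, __NV_E4M3);

#pragma unroll
  for (int i = 0; i < 8; ++i) out[threadIdx.x * 8 + i] = sret[i];
}
```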
@@ -597,7 +597,9 @@ private:
         }
       }
       // Update the best plan if this one uses fewer registers
-      if (reg_num < min_reg_num) {
+      if (reg_num < min_reg_num ||
+          (reg_num == min_reg_num &&
+           attempt_infer_root < min_reg_num_infer_root)) {
         best_infer_list =
             BackupInferList();  // Use backup to avoid moving out infer_list_
         best_layout_map = tmp_layout_map;
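The new condition is a lexicographic preference: strictly fewer registers still wins, and on a register tie the plan with the smaller `attempt_infer_root` value wins. A minimal sketch of the equivalent comparison, assuming integer counters (the surrounding class and the exact semantics of `attempt_infer_root` are not shown in this hunk):

```cuda
#include <utility>

// Hedged sketch: equivalent lexicographic form of the expanded condition.
// Only the four variable names come from the diff; integer types are assumed.
bool IsBetterPlan(int reg_num, int attempt_infer_root,
                  int min_reg_num, int min_reg_num_infer_root) {
  // std::pair::operator< compares the first elements, then the second, which
  // matches: reg_num < min_reg_num ||
  //          (reg_num == min_reg_num &&
  //           attempt_infer_root < min_reg_num_infer_root)
  return std::make_pair(reg_num, attempt_infer_root) <
         std::make_pair(min_reg_num, min_reg_num_infer_root);
}
```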
@@ -787,7 +789,18 @@ private:
         }
       });
-      if (has_non_local && !has_reducer) {
+      // If a cast operation exists, vectorization may still be required
+      bool has_cast_operations = false;
+      PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          // Check if this is a non-reducer store with Cast operation
+          if (store->value.as<CastNode>()) {
+            has_cast_operations = true;
+          }
+        }
+      });
+      if ((has_non_local || has_cast_operations) && !has_reducer) {
         for_node = VectorizeLoop(for_node);
       }
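The practical effect: a loop whose stores all target local fragments used to skip `VectorizeLoop`, so a pure cast loop stayed scalar and never reached the packed-cast paths above. With the extra `has_cast_operations` check it is vectorized and the cast lowers to two-lane conversions. A hedged sketch of the difference for a float -> half fragment cast (the names `A_local`/`B_local` and the wrapper kernel are illustrative, not the generated source):

```cuda
#include <cuda_fp16.h>

// Hedged sketch: illustrative only. A_local/B_local and the wrapper kernel
// are assumptions, not the generated source's actual names.
__global__ void fragment_cast_sketch(const float* __restrict__ A,
                                     half* __restrict__ B) {
  __align__(32) float A_local[8];
  __align__(16) half B_local[8];
#pragma unroll
  for (int i = 0; i < 8; ++i) A_local[i] = A[threadIdx.x * 8 + i];

  // Before this change: the register-only cast loop stayed scalar.
  //   for (int i = 0; i < 8; ++i) B_local[i] = __float2half(A_local[i]);

  // After: the loop is vectorized and the 8-lane CastNode path emits
  // packed conversions, two lanes per __float22half2_rn call.
  ((half2*)(&B_local))[0] = __float22half2_rn(*(float2*)(&(A_local)));
  ((half2*)(&B_local))[1] = __float22half2_rn(*((float2*)(&(A_local)) + 1));
  ((half2*)(&B_local))[2] = __float22half2_rn(*((float2*)(&(A_local)) + 2));
  ((half2*)(&B_local))[3] = __float22half2_rn(*((float2*)(&(A_local)) + 3));

#pragma unroll
  for (int i = 0; i < 8; ++i) B[threadIdx.x * 8 + i] = B_local[i];
}
```

This is also what the new `parallel_vectorized_cast_kernel` test below exercises: an explicit `T.Parallel` cast between fragments that should now produce the packed intrinsics in the generated source.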
@@ -17,8 +17,8 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     @T.prim_func
     def main(
-            A: T.Tensor[(M), dtype_A], # noqa: F821
-            B: T.Tensor[(M), dtype_B], # noqa: F821
+            A: T.Tensor[(M,), dtype_A], # noqa: F821
+            B: T.Tensor[(M,), dtype_B], # noqa: F821
     ):
         with T.Kernel(1, threads=128):
             T.copy(A, B)
@@ -26,6 +26,27 @@ def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
     return main
+
+
+@tilelang.jit
+def parallel_vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
+    assert M % 256 == 0
+
+    @T.prim_func
+    def main(
+            A: T.Tensor[(M,), dtype_A], # noqa: F821
+            B: T.Tensor[(M,), dtype_B], # noqa: F821
+    ):
+        with T.Kernel(1, threads=128):
+            A_local = T.alloc_fragment((M,), dtype_A)
+            B_local = T.alloc_fragment((M,), dtype_B)
+            T.copy(A, A_local)
+            for i in T.Parallel(M):
+                B_local[i] = A_local[i]
+            T.copy(B_local, B)
+
+    return main
 
 
 def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
     """Run the vectorized cast kernel and check the correctness.
 
     Args:
@@ -37,17 +58,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
     M = 128 * lanes
     kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
+    kernel_parallel = parallel_vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
 
     A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
     B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
+    C = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
 
     kernel(A, B)
+    kernel_parallel(A, C)
     torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), C)
 
     code = kernel.get_kernel_source()
+    code_parallel = kernel_parallel.get_kernel_source()
-    assert check_str in code, \
+    assert check_str in code and check_str in code_parallel, \
         f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"