Unverified commit a148d62a authored by Tong WU, committed by GitHub

[Feature] Enhance vectorized conversion support in CUDA codegen (#1095)

* [Feature] Add vectorized float16 and float32 conversion support in CUDA codegen

* Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn.
* Enhanced the existing code to support both directions of conversion based on the lane count.
* Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang.

* [Feature] Add float32 to float8 conversion support in CUDA codegen

* Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method.
* Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations.
* Enhanced type handling for better compatibility with TileLang, particularly for float8 types.

* lint

* fix a bug

* [Enhancement] Support lanes=4 cases and add unit test for vectorized cast

* lint

* [Feature] Refactor bf16 conversion operations and remove legacy compile flags

* lint
parent 86c8bb46
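
For context, a minimal device-side sketch of the CUDA intrinsics the new conversion paths rely on. This snippet is not part of the patch; the kernel name, buffer layout, and launch shape are illustrative only.

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

// Each thread converts one 2-wide vector, mirroring the lanes == 2 code paths.
__global__ void convert2_demo(const half2 *h2_in, float2 *f2_out, half2 *h2_out,
                              __nv_bfloat162 *bf2_out,
                              __nv_fp8x2_storage_t *fp8x2_out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  float2 f2 = __half22float2(h2_in[i]);    // half2 -> float2
  f2_out[i] = f2;
  h2_out[i] = __float22half2_rn(f2);       // float2 -> half2 (round to nearest even)
  bf2_out[i] = __float22bfloat162_rn(f2);  // float2 -> bfloat162
  // float2 -> packed fp8x2; pass __NV_E5M2 instead of __NV_E4M3 for the E5M2 format.
  fp8x2_out[i] = __nv_cvt_float2_to_fp8x2(f2, __NV_SATFINITE, __NV_E4M3);
}
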
@@ -20,11 +20,9 @@ def get_bwd_configs():
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -140,11 +138,9 @@ def flashattn_fwd(
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,11 +176,9 @@ def make_dq_layout(dQ):
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,11 +199,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
     return flash_bwd_post
-@tilelang.jit(
-    pass_configs={
+@tilelang.jit(pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd(batch,
                   heads,
                   seq_len,
......
@@ -23,11 +23,9 @@ def get_configs():
     rep=100,
 )
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
......
@@ -20,11 +20,9 @@ def get_bwd_configs():
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -137,11 +135,9 @@ def flashattn_fwd(
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -177,11 +173,9 @@ def make_dq_layout(dQ):
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -202,11 +196,9 @@ def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"
     return flash_bwd_post
-@tilelang.jit(
-    pass_configs={
+@tilelang.jit(pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd(
     batch,
     heads,
......
@@ -18,11 +18,9 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
......
@@ -19,11 +19,9 @@ def get_configs():
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
......
@@ -900,56 +900,123 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   stream << ' ' << sret << ";\n";
   std::string src = SSAGetID(PrintExpr(op->value), from_ty);
-  // Handle bfloat16 special cases with supported ops
-  bool used_bf16_op = false;
-  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
-    std::ostringstream func_name;
-    if (from_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (from_ty.is_float()) {
-      func_name << "float";
-    }
-    if (from_ty.lanes() > 1) {
-      func_name << from_ty.lanes();
-    }
-    func_name << "2";
-    if (target_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (target_ty.is_float()) {
-      func_name << "float";
-    } else if (target_ty == DataType::Int(16)) {
-      func_name << "int16";
-    }
-    if (target_ty.lanes() > 1) {
-      func_name << target_ty.lanes();
-    }
-    auto fname = func_name.str();
-    if (bf16_supported_ops_.count(fname)) {
-      used_bf16_op = true;
-      stream << "#ifdef ENABLE_BF16\n";
-      PrintIndent();
-      stream << "reinterpret_cast<";
-      if (target_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(target_ty.element_of(), stream);
-      }
-      if (target_ty.lanes() > 1) {
-        stream << target_ty.lanes();
-      }
-      stream << " &>(" << sret << ") = fastertransformer::" << fname
-             << "(reinterpret_cast<";
-      if (from_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(from_ty.element_of(), stream);
-      }
-      if (from_ty.lanes() > 1) {
-        stream << from_ty.lanes();
-      }
-      stream << " const &>(" << src << "));\n";
-      stream << "#else\n";
-    }
-  }
+  // Handle conversion between float16 and float32
+  if (from_ty.is_float16() && target_ty.is_float()) {
+    // Use __half22float2 for vectorized conversion (half2 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // half2 -> float2
+      PrintIndent();
+      stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // half4 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_float16()) {
+    // Use __float22half2_rn for vectorized conversion (float2 -> half2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> half2
+      PrintIndent();
+      stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> half4
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+  // Handle conversion between bfloat16 and float32
+  if (from_ty.is_bfloat16() && target_ty.is_float()) {
+    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // bfloat162 -> float2
+      PrintIndent();
+      stream << sret
+             << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // bfloat162x2 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
+    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> bfloat162
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
+             << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> bfloat162x2
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+  // Handle conversion from float32 to float8 (E4M3/E5M2)
+  if (from_ty.is_float() &&
+      (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
+    // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
+    // (float2 -> fp8x2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> fp8x2
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
+             << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
+             << src << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> fp8x4
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+    }
+  }
@@ -964,9 +1031,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     PrintVecElemStore(sret, target_ty, i, val.str());
   }
-  if (used_bf16_op) {
-    stream << "#endif\n";
-  }
   os << sret;
 }
......
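
For readability, a sketch of roughly what the emitted strings above expand to for a lanes == 4 fp16 -> fp32 cast; the struct and variable names here are placeholders for whatever the codegen actually picks, not captured output.

#include <cuda_fp16.h>

// Stand-in for a 4-lane half vector as the generated kernel would hold it.
struct __align__(8) half4_t {
  __half x, y, z, w;
};

__device__ float4 cast_half4_to_float4(half4_t src) {
  float4 dst;
  // Two half2 -> float2 conversions fill the low and high halves of the float4,
  // matching the lanes == 4 branch of the fp16 -> fp32 path above.
  ((float2 *)(&dst))[0] = __half22float2(*(half2 *)(&(src)));
  ((float2 *)(&dst))[1] = __half22float2(*((half2 *)(&(src)) + 1));
  return dst;
}
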
import torch
import tilelang.testing
import tilelang.language as T
str2dtype = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float8_e4m3": torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2,
}
@tilelang.jit
def vectorized_cast_kernel(M: int, dtype_A: str, dtype_B: str):
assert M % 256 == 0
@T.prim_func
def main(
A: T.Tensor[(M), dtype_A], # noqa: F821
B: T.Tensor[(M), dtype_B], # noqa: F821
):
with T.Kernel(1, threads=128):
T.copy(A, B)
return main
def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, lanes: int = 2):
"""Run the vectorized cast kernel and check the correctness.
Args:
src_dtype_str: The source data type string.
dst_dtype_str: The destination data type string.
check_str: Used to ensure vectorized cast is used.
lanes: The number of lanes of the source and destination data types.
"""
M = 128 * lanes
kernel = vectorized_cast_kernel(M, src_dtype_str, dst_dtype_str)
A = torch.randn(M, dtype=str2dtype[src_dtype_str]).cuda()
B = torch.zeros(M, dtype=str2dtype[dst_dtype_str]).cuda()
kernel(A, B)
torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
code = kernel.get_kernel_source()
assert check_str in code, \
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
def test_vectorized_cast():
# fp32 -> fp16
run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
# fp16 -> fp32
run_vectorized_cast("float16", "float32", "__half22float2", 2)
run_vectorized_cast("float16", "float32", "__half22float2", 4)
# fp32 -> fp8_e4m3
run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
# fp32 -> fp8_e5m2
run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
# fp32 -> bf16
run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
# bf16 -> fp32
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
if __name__ == "__main__":
tilelang.testing.main()