Commit 8cc8db52 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Add tf32 casting to GEMM templates (#556)

* Add tf32 casting functionality to GEMM templates

- Introduced a `cast_float_to_tf32` function to convert float32 values to tfloat32 format across gemm_sm80, gemm_sm89, and gemm_sm90 templates.
- Implemented conditional casting in relevant sections of the GEMM operations to ensure compatibility with tfloat32 types.
- Enhanced the handling of tensor views to support the new casting logic, ensuring correct tf32 results in matrix operations.

* lint fix

* Refactor tfloat32 casting logic in GEMM templates

- Replaced the `is_tfloat32` boolean with `need_tfloat32_cast` to improve clarity and accuracy in determining when to cast float32 to tfloat32.
- Updated relevant sections in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to utilize the new casting logic, enhancing compatibility with tfloat32 types.
- Ensured consistent application of casting across tensor views, improving correctness in matrix operations.

* Refactor GEMM template functions for improved readability

- Simplified the function signature of `body_rs` in both `gemm_sm80` and `gemm_sm90` templates for better clarity.
- Adjusted the casting logic in `gemm_sm90` to ensure consistent application of `cast_float_to_tf32` across tensor views, improving correctness and maintainability.

* Enhance tf32 casting logic in GEMM templates

- Updated the `cast_float_to_tf32` function in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to conditionally apply the casting only if the input is finite, improving robustness.
- Simplified the `need_tfloat32_cast` logic to clarify the conditions under which tfloat32 casting is required, enhancing code readability and maintainability.

* Refactor GEMM template functions and layout inference logic

- Removed the `cast_float_to_tf32` function from `gemm_sm90` and updated the `body_sr` function to streamline the casting process for tensor views, enhancing code clarity and maintainability.
- Improved layout inference in `layout_inference.cc` by adding checks for the layout map's definition, ensuring robustness in handling layout annotations.
- Simplified the handling of layout maps in the `annotate_layout` function, allowing for more flexible layout definitions and error handling.
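
For reference, the rounding that `cast_float_to_tf32` performs can be reproduced in a small host-only sketch. The helper adds 0x1000 (half of the lowest mantissa bit that tf32 keeps) to the raw float32 bits of finite values and then stores the result as CUTLASS's `tfloat32_t`, so the subsequent tf32 MMA sees a value rounded to nearest rather than truncated toward zero. The code below is an illustration under stated assumptions, not part of the commit: the name `emulate_tf32_round`, the explicit mantissa mask, and the `main` driver are added for demonstration, whereas the templates leave the low 13 bits in place for the hardware to ignore.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Host-only sketch (not from the commit): emulate the effect of
// cast_float_to_tf32. Bit 12 is the highest of the 13 mantissa bits that
// tf32 discards, so adding 0x1000u before dropping them rounds to nearest
// instead of truncating. Non-finite values (inf/NaN) pass through so the
// bias cannot disturb the exponent or a NaN payload.
static float emulate_tf32_round(float a) { // hypothetical helper, demo only
  uint32_t x;
  std::memcpy(&x, &a, sizeof(x)); // bit-level view of the float32 value
  if (std::isfinite(a)) {
    x += 0x1000u;     // same rounding bias the commit applies
    x &= 0xFFFFE000u; // drop the low 13 mantissa bits (made explicit here)
  }
  float r;
  std::memcpy(&r, &x, sizeof(r));
  return r;
}

int main() {
  const float samples[] = {1.0f, 1.0000001f, 3.14159265f, 0.3f};
  for (float v : samples) {
    std::printf("%.9g -> %.9g\n", v, emulate_tf32_round(v));
  }
  return 0;
}

Inputs whose low 13 mantissa bits are already zero (for example 1.0f) print unchanged, while values such as 0.3f show the nearest representable tf32 value; this is the behaviour the `need_tfloat32_cast` guard enables only when both operands are float32.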
parent 24403aea
......@@ -199,6 +199,14 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
  uint32_t x = reinterpret_cast<uint32_t const &>(a);
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
......@@ -211,6 +219,11 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -279,6 +292,10 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -304,7 +321,9 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -315,6 +334,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -340,7 +362,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -351,6 +375,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -303,6 +303,14 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
  uint32_t x = reinterpret_cast<uint32_t const &>(a);
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
......@@ -315,6 +323,11 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -379,10 +392,19 @@ public:
// workaround
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -409,6 +431,10 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -419,6 +445,11 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -444,7 +475,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -455,6 +488,11 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -13,6 +13,14 @@ namespace cute {
using namespace SM90;
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
  uint32_t x = reinterpret_cast<uint32_t const &>(a);
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  a = tfloat32_t::bitcast(x);
};
namespace tl_wgmma {
using namespace cutlass::gemm::collective::detail; // ss_smem_selector
......@@ -28,6 +36,12 @@ public:
tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
// A_type will be tfloat32_t if A_type_raw is float
std::is_same<B_type_raw, float>::value;
// B_type will be tfloat32_t if B_type_raw is float
static constexpr GMMA::Major GmmaMajorA =
trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
......@@ -79,6 +93,10 @@ public:
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
......@@ -120,7 +138,10 @@ public:
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
warpgroup_arrive();
......@@ -140,16 +161,6 @@ public:
}
warpgroup_fence_operand(acc);
warpgroup_fence_operand(tCrA);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
}
};
......@@ -361,6 +372,13 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<A_type, tfloat32_t>::value &&
std::is_same<B_type_raw, float>::value &&
std::is_same<B_type, tfloat32_t>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -428,6 +446,10 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -453,7 +475,9 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
......@@ -464,6 +488,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -489,7 +516,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
......@@ -500,6 +529,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -491,9 +491,12 @@ private:
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout>
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>();
ICHECK(map.defined()) << "layout map is not defined";
ICHECK(map.value().defined()) << "layout map is not defined";
for (const auto &[var, layout] : map.value()) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
......
"""The language interface for tl programs."""
from typing import Optional, Callable, Dict
# from .parser import *
# now is fully compatible with the upstream
# tir script
......@@ -109,8 +109,16 @@ def annotate_layout(layout_map: Dict):
return main
"""
# layout_map is a dictionary of buffer to layout
_layout_map = {}
for buffer, layout in layout_map.items():
    if isinstance(layout, Layout):
        _layout_map[buffer.data] = layout
    elif isinstance(layout, Callable):
        _layout_map[buffer.data] = Layout(buffer.shape, layout)
    else:
        raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_padding(padding_map: Dict):
......