Commit 8cc8db52 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Add tf32 casting to GEMM templates (#556)

* Add tf32 casting functionality to GEMM templates

- Introduced a `cast_float_to_tf32` function to convert float32 values to tfloat32 format across gemm_sm80, gemm_sm89, and gemm_sm90 templates.
- Implemented conditional casting in relevant sections of the GEMM operations to ensure compatibility with tfloat32 types.
- Enhanced the handling of tensor views to support the new casting logic, improving the correctness of matrix operations on float32 inputs.
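
As a side note on the rounding itself, here is a minimal host-side sketch of the same bit trick, using plain `float` and `std::memcpy` in place of `tfloat32_t::bitcast` (the function name and the `main` driver are illustrative, not part of the patch):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Add half of the 13-bit mantissa tail (0x1000) so the value rounds to the
// nearest tf32-representable float; the tf32 MMA only reads the top 19 bits,
// so the remaining tail can be left in place. Non-finite inputs are skipped
// so Inf/NaN bit patterns are not disturbed.
static float round_to_tf32_sketch(float a) {
  uint32_t x;
  std::memcpy(&x, &a, sizeof(x));
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  float r;
  std::memcpy(&r, &x, sizeof(r));
  return r;
}

int main() {
  float v = 0.1f;  // not exactly representable with a 10-bit tf32 mantissa
  uint32_t before, after;
  std::memcpy(&before, &v, sizeof(before));
  float r = round_to_tf32_sketch(v);
  std::memcpy(&after, &r, sizeof(after));
  std::printf("0x%08X -> 0x%08X\n", (unsigned)before, (unsigned)after);
  return 0;
}
```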

* lint fix

* Refactor tfloat32 casting logic in GEMM templates

- Replaced the `is_tfloat32` boolean with `need_tfloat32_cast` to improve clarity and accuracy in determining when to cast float32 to tfloat32.
- Updated relevant sections in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to utilize the new casting logic, enhancing compatibility with tfloat32 types.
- Ensured consistent application of casting across tensor views, improving performance and correctness in matrix operations.
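
A stripped-down sketch of how the `need_tfloat32_cast` gate behaves, with a plain array standing in for a CuTe register fragment and a local rounding helper in place of `cast_float_to_tf32` (names here are illustrative, not from the patch):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <type_traits>

// Same rounding bias as cast_float_to_tf32, written against plain float.
static void round_bits_to_tf32(float &a) {
  uint32_t x;
  std::memcpy(&x, &a, sizeof(x));
  if (std::isfinite(a)) {
    x += 0x1000u;
  }
  std::memcpy(&a, &x, sizeof(x));
}

// The cast is only compiled in when both raw operand types are float; for
// half/bf16 operands the branch is discarded at compile time, mirroring the
// `if constexpr (need_tfloat32_cast)` blocks inside the GEMM body functions.
template <typename A_type_raw, typename B_type_raw, typename T, int N>
void maybe_round_operand(T (&frag)[N]) {
  constexpr bool need_tfloat32_cast =
      std::is_same<A_type_raw, float>::value &&
      std::is_same<B_type_raw, float>::value;
  if constexpr (need_tfloat32_cast) {
    for (int i = 0; i < N; ++i) {
      round_bits_to_tf32(frag[i]);
    }
  }
}

int main() {
  float frag[4] = {0.1f, 0.2f, 0.3f, 0.4f};
  maybe_round_operand<float, float>(frag);  // rounds the fragment in place
  return 0;
}
```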

* Refactor GEMM template functions for improved readability

- Simplified the function signature of `body_rs` in both `gemm_sm80` and `gemm_sm90` templates for better clarity.
- Adjusted the casting logic in `gemm_sm90` to ensure consistent application of `cast_float_to_tf32` across tensor views, enhancing performance and maintainability.

* Enhance tf32 casting logic in GEMM templates

- Updated the `cast_float_to_tf32` function in `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` to conditionally apply the casting only if the input is finite, improving robustness.
- Simplified the `need_tfloat32_cast` logic to clarify the conditions under which tfloat32 casting is required, enhancing code readability and maintainability.
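
The reason for the finiteness guard, worked through on one bit pattern (an illustration, not code from the patch): without the check, the rounding bias would turn an infinity into a NaN.

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  float inf = INFINITY;  // bit pattern 0x7F800000
  uint32_t x;
  std::memcpy(&x, &inf, sizeof(x));
  x += 0x1000u;          // 0x7F801000: exponent all ones, non-zero mantissa,
  float corrupted;       // which IEEE 754 decodes as a NaN
  std::memcpy(&corrupted, &x, sizeof(x));
  std::printf("isnan = %d\n", std::isnan(corrupted));  // prints isnan = 1
  return 0;
}
```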

* Refactor GEMM template functions and layout inference logic

- Removed the `cast_float_to_tf32` function from `gemm_sm90` and updated the `body_sr` function to streamline the casting process for tensor views, enhancing code clarity and maintainability.
- Improved layout inference in `layout_inference.cc` by adding checks for the layout map's definition, ensuring robustness in handling layout annotations.
- Simplified the handling of layout maps in the `annotate_layout` function, allowing for more flexible layout definitions and error handling.
parent 24403aea
@@ -199,6 +199,14 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
@@ -211,6 +219,11 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
@@ -279,6 +292,10 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -304,7 +321,9 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
@@ -315,6 +334,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -340,7 +362,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
@@ -351,6 +375,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
...
@@ -303,6 +303,14 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
@@ -315,6 +323,11 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
@@ -379,10 +392,19 @@ public:
// workaround
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
CUTE_UNROLL
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -409,6 +431,10 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
@@ -419,6 +445,11 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -444,7 +475,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
@@ -455,6 +488,11 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
...
@@ -13,6 +13,14 @@ namespace cute {
using namespace SM90;
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
namespace tl_wgmma {
using namespace cutlass::gemm::collective::detail; // ss_smem_selector
@@ -28,6 +36,12 @@ public:
tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
// A_type will be tfloat32_t if A_type_raw is float
std::is_same<B_type_raw, float>::value;
// B_type will be tfloat32_t if B_type_raw is float
static constexpr GMMA::Major GmmaMajorA =
trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
@@ -79,6 +93,10 @@ public:
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
@@ -120,7 +138,10 @@ public:
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
warpgroup_arrive();
@@ -140,16 +161,6 @@ public:
}
warpgroup_fence_operand(acc);
warpgroup_fence_operand(tCrA);
// warpgroup_fence_operand(acc);
// warpgroup_arrive();
// gemm(tiled_mma, tCrA(_, _, _), tCrB(_, _, _), acc);
// warpgroup_commit_batch();
// if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
// warpgroup_fence_operand(acc);
}
};
@@ -361,6 +372,13 @@ public:
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<A_type, tfloat32_t>::value &&
std::is_same<B_type_raw, float>::value &&
std::is_same<B_type, tfloat32_t>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
@@ -428,6 +446,10 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -453,7 +475,9 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
@@ -464,6 +488,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
@@ -489,7 +516,9 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
@@ -500,6 +529,9 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
...
@@ -491,9 +491,12 @@ private:
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
if (op->annotations.count(attr::kLayoutMap)) {
auto map =
op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) {
// Check if the layout map is Map<Var, Layout>
auto map = op->annotations.Get(attr::kLayoutMap).as<Map<Var, Layout>>();
ICHECK(map.defined()) << "layout map is not defined";
ICHECK(map.value().defined()) << "layout map is not defined";
for (const auto &[var, layout] : map.value()) {
ICHECK(buffer_data_to_buffer_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
...
"""The language interface for tl programs."""
from typing import Optional
from typing import Optional, Callable, Dict
# from .parser import *
# now is fully compatible with the upstream
# tir script
@@ -109,8 +109,16 @@ def annotate_layout(layout_map: Dict):
return main
"""
# layout_map is a dictionary of buffer to layout
layout_map = {buffer.data: layout for buffer, layout in layout_map.items()}
return block_attr({"layout_map": layout_map})
_layout_map = {}
for buffer, layout in layout_map.items():
if isinstance(layout, Layout):
_layout_map[buffer.data] = layout
elif isinstance(layout, Callable):
_layout_map[buffer.data] = Layout(buffer.shape, layout)
else:
raise ValueError(f"Invalid layout: {layout}")
return block_attr({"layout_map": _layout_map})
def annotate_padding(padding_map: Dict):
...