Commit 9ba8b480 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Phaseout tf32 Casting from GEMM Templates (#573)

* [Feature] Add Quarter Bank Swizzle Layout and Update GEMM Layout Logic

- Introduced a new `makeQuarterBankSwizzleLayout` function for layout swizzling of 32 bytes.
- Updated `makeGemmABLayout` to include an `enable_padding` parameter, allowing for conditional layout selection between padded and quarter bank swizzle layouts.
- Adjusted layout inference in GEMM operations to utilize the new quarter bank swizzle layout when appropriate.
- Enhanced bulk copy operations to recognize and handle the new layout type, improving memory access patterns.

* lint fix

* [Refactor] Update GEMM Layout Functions and Inference Logic

- Removed the `enable_padding` parameter from `makeGemmABLayout` to simplify its signature.
- Introduced `makeGemmABLayoutHopper` for enhanced layout handling specific to Hopper architecture.
- Updated layout inference in GEMM operations to utilize the new `makeGemmABLayoutHopper` function, improving clarity and maintainability in layout selection.
- Adjusted related layout functions to ensure consistent behavior across different architectures.

* [Refactor] Remove tf32 Casting Logic from GEMM Templates

- Eliminated the `cast_float_to_tf32` function from `gemm_sm80`, `gemm_sm89`, and `gemm_sm90` templates to streamline the code.
- Removed conditional casting logic for float32 to tfloat32 conversion, enhancing clarity and maintainability.
- Updated relevant sections in GEMM operations to reflect the removal of casting, ensuring consistent behavior across templates.
- Adjusted tensor view handling to improve performance and accuracy in matrix operations.

* Update bulk_copy.cc

* Fix profiler initialization in GEMM test by removing TensorSupplyType argument for improved flexibility.
parent 61ee0bec
......@@ -199,14 +199,6 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
......@@ -220,10 +212,6 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -292,10 +280,6 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -321,9 +305,6 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -334,9 +315,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -362,9 +340,6 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -375,9 +350,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -303,14 +303,6 @@ struct OperandTraits<64, N, K, false, num_warp_n,
using Copy = DefaultCopy;
};
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum, typename A_type_raw,
typename B_type_raw, typename C_type_raw>
......@@ -324,10 +316,6 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<B_type_raw, float>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -397,14 +385,6 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -431,10 +411,6 @@ public:
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -445,11 +421,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -475,9 +446,6 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
if constexpr (clear_accum) {
clear(acc);
}
......@@ -488,11 +456,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
// Convert float32 to tfloat32 because tfloat32 mma cannot truncate
// float32 automatically
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -13,14 +13,6 @@ namespace cute {
using namespace SM90;
template <typename T> CUTE_HOST_DEVICE static void cast_float_to_tf32(T &a) {
uint32_t x = reinterpret_cast<uint32_t const &>(a);
if (std::isfinite(a)) {
x += 0x1000u;
}
a = tfloat32_t::bitcast(x);
};
namespace tl_wgmma {
using namespace cutlass::gemm::collective::detail; // ss_smem_selector
......@@ -36,12 +28,6 @@ public:
tfloat32_t, B_type_raw>;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
// A_type will be tfloat32_t if A_type_raw is float
std::is_same<B_type_raw, float>::value;
// B_type will be tfloat32_t if B_type_raw is float
static constexpr GMMA::Major GmmaMajorA =
trans_A ? GMMA::Major::MN : GMMA::Major::K;
static constexpr GMMA::Major GmmaMajorB =
......@@ -93,10 +79,6 @@ public:
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
// warpgroup_arrive();
......@@ -138,10 +120,7 @@ public:
Tensor acc =
make_tensor(make_rmem_ptr(reinterpret_cast<C_type *>(pC)),
partition_shape_C(tiled_mma, Shape<Int<M>, Int<N>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
warpgroup_fence_operand(tCrA);
warpgroup_fence_operand(acc);
warpgroup_arrive();
......@@ -373,12 +352,6 @@ public:
tfloat32_t, A_type_raw>::type;
using C_type = C_type_raw;
static constexpr bool need_tfloat32_cast =
std::is_same<A_type_raw, float>::value &&
std::is_same<A_type, tfloat32_t>::value &&
std::is_same<B_type_raw, float>::value &&
std::is_same<B_type, tfloat32_t>::value;
using Instruction =
DispatchInstruction<A_type, B_type, C_type, num_warp_m, num_warp_n, N>;
......@@ -446,10 +419,6 @@ public:
for (int k = 0; k < size<2>(tCrA); ++k) {
copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k));
copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -475,9 +444,6 @@ public:
Tensor tCrA =
make_tensor(make_rmem_ptr(reinterpret_cast<A_type *>(pA)),
partition_shape_A(tiled_mma, Shape<Int<M>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA, cast_float_to_tf32<A_type>);
}
auto tCrB_view = make_tensor(tCrB.data(), remove_swizzle(tCrB.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
......@@ -488,9 +454,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_B, tCsB(_, _, k + 1), tCrB_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB_view(_, _, k), cast_float_to_tf32<B_type>);
}
gemm(tiled_mma, tCrA(_, _, k), tCrB_view(_, _, k), acc);
}
}
......@@ -516,9 +479,6 @@ public:
Tensor tCrB =
make_tensor(make_rmem_ptr(reinterpret_cast<B_type *>(pB)),
partition_shape_B(tiled_mma, Shape<Int<N>, Int<K>>{}));
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrB, cast_float_to_tf32<B_type>);
}
auto tCrA_view = make_tensor(tCrA.data(), remove_swizzle(tCrA.layout()));
if constexpr (clear_accum) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
......@@ -529,9 +489,6 @@ public:
if (k < size<2>(tCrA) - 1) {
copy(tiled_copy_A, tCsA(_, _, k + 1), tCrA_copy_view(_, _, k + 1));
}
if constexpr (need_tfloat32_cast) {
cute::for_each(tCrA_view(_, _, k), cast_float_to_tf32<A_type>);
}
gemm(tiled_mma, tCrA_view(_, _, k), tCrB(_, _, k), acc);
}
}
......
......@@ -91,6 +91,11 @@ def run_gemm(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
......@@ -383,7 +388,9 @@ def run_gemm_sr(
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
A = A.to(torch.float)
B = B.to(torch.float)
C = torch.matmul(A, B)
C = C.to(torch.__getattribute__(out_dtype))
return C
......
......@@ -424,6 +424,8 @@ class AutoTuner:
logger.debug(f"Error: {e}")
continue
logging.debug(f"Config {config} latency: {latency} at index {i}")
if latency < best_latency:
best_latency = latency
best_config = config
......
......@@ -6,4 +6,5 @@ from .cuda_driver import (
get_max_dynamic_shared_size_bytes, # noqa: F401
get_persisting_l2_cache_max_size, # noqa: F401
get_num_sms, # noqa: F401
get_registers_per_block, # noqa: F401
)
......@@ -190,3 +190,11 @@ def get_num_sms(device_id: int = 0) -> int:
return prop.multiProcessorCount
else:
raise RuntimeError("Failed to get device properties.")
def get_registers_per_block(device_id: int = 0) -> int:
prop = get_cuda_device_properties(device_id)
if prop:
return prop.regsPerBlock
else:
raise RuntimeError("Failed to get device properties.")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment