Commit 63b152d6 authored by danyao12

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents ae2d7d2b 14c3cfb1
@@ -5,3 +5,4 @@ include_directories(AFTER
 add_subdirectory(01_fmha)
 add_subdirectory(02_layernorm2d)
 add_subdirectory(03_gemm)
+add_subdirectory(04_img2col)
@@ -97,13 +97,6 @@
 #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
 #endif
-//
-// Instances supports in the current CK build
-//
-#ifndef CK_ENABLE_INSTANCES_ONLY
-#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
-#endif
 //
 // CK kernels which support XDL (MI series)
 //
......
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
 hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+hip_check_error(hipEventDestroy(start));
+hip_check_error(hipEventDestroy(stop));
 return total_time / nrepeat;
 }
 else
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+hip_check_error(hipEventDestroy(start));
+hip_check_error(hipEventDestroy(stop));
 return total_time / nrepeat;
 }
 else
......
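Note: the two hunks above add hipEventDestroy calls so that the start/stop events created for timing are released after hipEventElapsedTime. A minimal, stand-alone sketch of the same create/record/elapse/destroy lifecycle; dummy_kernel and time_kernel are placeholders (not CK code) and error-return checking is omitted for brevity:

    #include <hip/hip_runtime.h>

    __global__ void dummy_kernel() {} // placeholder kernel for illustration

    float time_kernel(int nrepeat)
    {
        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, nullptr);
        for(int i = 0; i < nrepeat; ++i)
            dummy_kernel<<<1, 64>>>();
        hipEventRecord(stop, nullptr);
        hipEventSynchronize(stop);

        float total_time = 0.f;
        hipEventElapsedTime(&total_time, start, stop);

        // Without these two calls every timed run leaks a pair of events.
        hipEventDestroy(start);
        hipEventDestroy(stop);

        return total_time / nrepeat;
    }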
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
 }
 template <>
-__device__ static constexpr auto TailScheduler<1>()
+__device__ constexpr auto TailScheduler<1>()
 {
 // schedule
 constexpr auto num_ds_read_inst =
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
 }
 template <>
-__device__ static constexpr auto TailScheduler<2>()
+__device__ constexpr auto TailScheduler<2>()
 {
 // schedule
 constexpr auto num_ds_read_inst =
......
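Note: the TailScheduler hunks above (and the DppSelector/WmmaSelector/MfmaSelector hunks further down) drop the static keyword because C++ does not allow a storage-class specifier on an explicit specialization; only the primary declaration carries it. A minimal stand-alone sketch of that rule, with hypothetical names:

    struct Selector
    {
        template <typename T>
        static constexpr int Get(); // primary declaration: 'static' belongs here
    };

    // Explicit specialization: repeating 'static' here would be ill-formed
    // ("explicit specialization cannot have a storage class").
    template <>
    constexpr int Selector::Get<int>() { return 42; }

    int main() { return Selector::Get<int>() == 42 ? 0 : 1; }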
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
 typename vector_type<ComputeDataType,
 xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(
+xdlops_gemm.template Run<>(
 a_thread_vec.template AsType<mfma_input_type>(),
 b_thread_vec.template AsType<mfma_input_type>(),
 c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
 using mfma_input_type =
 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-b_thread_vec.template AsType<mfma_input_type>(),
-c_thread_buf_per_scale.GetVectorTypeReference(I0));
+xdlops_gemm.template Run<>(
+a_thread_vec.template AsType<mfma_input_type>(),
+b_thread_vec.template AsType<mfma_input_type>(),
+c_thread_buf_per_scale.GetVectorTypeReference(I0));
 });
 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
 constexpr index_t c_offset =
......
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
 typename vector_type<ComputeDataType,
 xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(
+xdlops_gemm.template Run<>(
 a_thread_vec.template AsType<mfma_input_type>(),
 b_thread_vec.template AsType<mfma_input_type>(),
 c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
 typename vector_type<ComputeDataType,
 xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(
+xdlops_gemm.template Run<>(
 a_thread_vec.template AsType<mfma_input_type>(),
 b_thread_vec.template AsType<mfma_input_type>(),
 c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
 using mfma_input_type =
 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-b_thread_vec.template AsType<mfma_input_type>(),
-c_thread_buf_per_scale.GetVectorTypeReference(I0));
+xdlops_gemm.template Run<>(
+a_thread_vec.template AsType<mfma_input_type>(),
+b_thread_vec.template AsType<mfma_input_type>(),
+c_thread_buf_per_scale.GetVectorTypeReference(I0));
 });
 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
 constexpr index_t c_offset =
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
 using mfma_input_type =
 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-b_thread_vec.template AsType<mfma_input_type>(),
-c_thread_buf_per_scale.GetVectorTypeReference(I0));
+xdlops_gemm.template Run<>(
+a_thread_vec.template AsType<mfma_input_type>(),
+b_thread_vec.template AsType<mfma_input_type>(),
+c_thread_buf_per_scale.GetVectorTypeReference(I0));
 });
 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
 constexpr index_t c_offset =
......
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
 typename vector_type<ComputeDataType,
 xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(
+xdlops_gemm.template Run<>(
 a_thread_vec.template AsType<mfma_input_type>(),
 b_thread_vec.template AsType<mfma_input_type>(),
 c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
 using mfma_input_type =
 typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
-xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
-b_thread_vec.template AsType<mfma_input_type>(),
-c_thread_buf_per_scale.GetVectorTypeReference(I0));
+xdlops_gemm.template Run<>(
+a_thread_vec.template AsType<mfma_input_type>(),
+b_thread_vec.template AsType<mfma_input_type>(),
+c_thread_buf_per_scale.GetVectorTypeReference(I0));
 });
 static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
 constexpr index_t c_offset =
......
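Note: the Run() → Run<>() hunks above add an explicit (empty) template argument list because, after the template keyword in a call on a dependent object, the name must form a template-id; recent compilers reject the bare form. A small stand-alone sketch of the pattern, using hypothetical names rather than CK's real types:

    #include <cstdio>

    struct XdlopsGemm
    {
        template <typename T = float>
        void Run(T a, T b) const { std::printf("%f\n", static_cast<double>(a + b)); }
    };

    template <typename Gemm>
    void blockwise_pipeline(const Gemm& xdlops_gemm)
    {
        // 'xdlops_gemm' has a dependent type, so the call needs the 'template'
        // keyword, and the name after it must be a template-id - hence 'Run<>'.
        xdlops_gemm.template Run<>(1.0f, 2.0f);
    }

    int main() { blockwise_pipeline(XdlopsGemm{}); }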
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
 #include "device_base.hpp"
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
 CElementwiseOperation c_element_op,
 ck::index_t KBatch = 1) = 0;
 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 virtual std::size_t GetWorkspaceSize(index_t MRaw,
 index_t NRaw,
 index_t KRaw,
 index_t StrideA,
 index_t StrideB,
-index_t StrideC) = 0;
+index_t StrideC) const = 0;
 };
 template <typename AElementwiseOperation,
......
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
 [[maybe_unused]] index_t K,
 [[maybe_unused]] index_t StrideA,
 [[maybe_unused]] index_t StrideB,
-index_t StrideC) override
+index_t StrideC) const override
 {
 return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
 }
+std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
+{
+    const auto* parg = dynamic_cast<const Argument*>(base_arg);
+    if(!parg)
+    {
+        std::ostringstream err;
+        err << "Provided argument pointer is not of an Argument class!"
+            << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+        throw std::runtime_error(err.str());
+    }
+    return GetWorkspaceSize(
+        parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
+}
 };
 } // namespace device
......
@@ -64,7 +64,7 @@ __global__ void
 const index_t N = gemm_desc_ptr[group_id].N;
 const index_t K = gemm_desc_ptr[group_id].K;
-if(M * N * K == 0)
+if(M == 0 || N == 0 || K == 0)
 return;
 const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
 const index_t N = gemm_descs[i].N_;
 const index_t K = gemm_descs[i].K_;
-if(M * N * K == 0)
+if(M == 0 || N == 0 || K == 0)
 {
 skipped_group_count_++;
 continue;
......
@@ -109,7 +109,7 @@ __global__ void
 N = gemm_desc_ptr[group_id].N;
 K = gemm_desc_ptr[group_id].K;
-if(M * N * K == 0)
+if(M == 0 || N == 0 || K == 0)
 {
 grid_size_grp = 0;
 continue;
......
@@ -68,7 +68,7 @@ __global__ void
 const index_t N = gemm_desc_ptr[group_id].N;
 const index_t K = gemm_desc_ptr[group_id].K;
-if(M * N * K == 0)
+if(M == 0 || N == 0 || K == 0)
 return;
 const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
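Note: the hunks above replace the product test because M * N * K can overflow a 32-bit index type; for a signed index type that overflow is undefined behaviour, and even with wraparound the product can come out as zero for a perfectly valid problem, so the group would be skipped. Checking each dimension against zero cannot misfire. A small illustration, written with unsigned arithmetic so the wraparound itself is well defined:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        // None of these dimensions is zero, yet the 32-bit product wraps to 0.
        std::uint32_t M = 65536, N = 65536, K = 1;

        std::cout << "product check says empty: " << (M * N * K == 0) << '\n';             // prints 1
        std::cout << "per-dim check says empty: " << (M == 0 || N == 0 || K == 0) << '\n'; // prints 0
    }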
@@ -419,6 +419,12 @@ struct UnaryAbs
 y = ck::math::abs(x);
 };
+template <>
+__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
+{
+    y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
+};
 };
 struct UnarySqrt
......
@@ -324,55 +324,55 @@ struct DppSelector
 static constexpr auto GetDpp();
 template <>
-static constexpr auto GetDpp<half_t, 8, 32>()
+constexpr auto GetDpp<half_t, 8, 32>()
 {
 return DppInstr::dpp8_f16_8x32x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 8, 16>()
+constexpr auto GetDpp<half_t, 8, 16>()
 {
 return DppInstr::dpp8_f16_8x16x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 16, 16>()
+constexpr auto GetDpp<half_t, 16, 16>()
 {
 return DppInstr::dpp8_f16_16x16x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 32, 8>()
+constexpr auto GetDpp<half_t, 32, 8>()
 {
 return DppInstr::dpp8_f16_32x8x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 1, 32>()
+constexpr auto GetDpp<half_t, 1, 32>()
 {
 return DppInstr::dpp8_f16_1x32x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 2, 32>()
+constexpr auto GetDpp<half_t, 2, 32>()
 {
 return DppInstr::dpp8_f16_2x32x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 2, 16>()
+constexpr auto GetDpp<half_t, 2, 16>()
 {
 return DppInstr::dpp8_f16_2x16x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 4, 16>()
+constexpr auto GetDpp<half_t, 4, 16>()
 {
 return DppInstr::dpp8_f16_4x16x2;
 }
 template <>
-static constexpr auto GetDpp<half_t, 4, 32>()
+constexpr auto GetDpp<half_t, 4, 32>()
 {
 return DppInstr::dpp8_f16_4x32x2;
 }
......
@@ -415,7 +415,7 @@ struct WmmaSelector
 static constexpr auto GetWmma();
 template <>
-static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
+constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
 {
 #ifdef __gfx12__
 return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
@@ -425,7 +425,7 @@ struct WmmaSelector
 }
 template <>
-static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
+constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
 {
 #ifdef __gfx12__
 return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
@@ -435,19 +435,19 @@ struct WmmaSelector
 }
 template <>
-static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
+constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
 {
 return WmmaInstr::wmma_f16_16x16x16_f16;
 }
 template <>
-static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
+constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
 {
 return WmmaInstr::wmma_bf16_16x16x16_bf16;
 }
 template <>
-static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
+constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
 {
 #ifdef __gfx12__
 return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
@@ -458,7 +458,7 @@ struct WmmaSelector
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 template <>
-static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
+constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
 {
 return WmmaInstr::wmma_i32_16x16x16_iu4;
 }
......
@@ -651,97 +651,97 @@ struct MfmaSelector
 static constexpr auto GetMfma();
 template <>
-static constexpr auto GetMfma<double, 16, 16>()
+constexpr auto GetMfma<double, 16, 16>()
 {
 return MfmaInstr::mfma_f64_16x16x4f64;
 }
 template <>
-static constexpr auto GetMfma<float, 64, 64>()
+constexpr auto GetMfma<float, 64, 64>()
 {
 return MfmaInstr::mfma_f32_32x32x1xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 32, 64>()
+constexpr auto GetMfma<float, 32, 64>()
 {
 return MfmaInstr::mfma_f32_32x32x1xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 16, 64>()
+constexpr auto GetMfma<float, 16, 64>()
 {
 return MfmaInstr::mfma_f32_16x16x1xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 8, 64>()
+constexpr auto GetMfma<float, 8, 64>()
 {
 return MfmaInstr::mfma_f32_4x4x1xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 4, 64>()
+constexpr auto GetMfma<float, 4, 64>()
 {
 return MfmaInstr::mfma_f32_4x4x1xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 32, 32>()
+constexpr auto GetMfma<float, 32, 32>()
 {
 return MfmaInstr::mfma_f32_32x32x2xf32;
 }
 template <>
-static constexpr auto GetMfma<float, 16, 16>()
+constexpr auto GetMfma<float, 16, 16>()
 {
 return MfmaInstr::mfma_f32_16x16x4xf32;
 }
 template <>
-static constexpr auto GetMfma<half_t, 64, 64>()
+constexpr auto GetMfma<half_t, 64, 64>()
 {
 return MfmaInstr::mfma_f32_32x32x4f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 32, 64>()
+constexpr auto GetMfma<half_t, 32, 64>()
 {
 return MfmaInstr::mfma_f32_32x32x4f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 32, 32>()
+constexpr auto GetMfma<half_t, 32, 32>()
 {
 return MfmaInstr::mfma_f32_32x32x8f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 16, 16>()
+constexpr auto GetMfma<half_t, 16, 16>()
 {
 return MfmaInstr::mfma_f32_16x16x16f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 16, 64>()
+constexpr auto GetMfma<half_t, 16, 64>()
 {
 return MfmaInstr::mfma_f32_16x16x4f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 8, 64>()
+constexpr auto GetMfma<half_t, 8, 64>()
 {
 return MfmaInstr::mfma_f32_4x4x4f16;
 }
 template <>
-static constexpr auto GetMfma<half_t, 4, 64>()
+constexpr auto GetMfma<half_t, 4, 64>()
 {
 return MfmaInstr::mfma_f32_4x4x4f16;
 }
 template <>
-static constexpr auto GetMfma<bhalf_t, 32, 32>()
+constexpr auto GetMfma<bhalf_t, 32, 32>()
 {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
 return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -751,7 +751,7 @@ struct MfmaSelector
 }
 template <>
-static constexpr auto GetMfma<bhalf_t, 16, 16>()
+constexpr auto GetMfma<bhalf_t, 16, 16>()
 {
 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
 return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -762,72 +762,72 @@ struct MfmaSelector
 #if defined(CK_USE_AMD_MFMA_GFX940)
 template <>
-static constexpr auto GetMfma<int8_t, 32, 32>()
+constexpr auto GetMfma<int8_t, 32, 32>()
 {
 return MfmaInstr::mfma_i32_32x32x16i8;
 }
 template <>
-static constexpr auto GetMfma<int8_t, 16, 16>()
+constexpr auto GetMfma<int8_t, 16, 16>()
 {
 return MfmaInstr::mfma_i32_16x16x32i8;
 }
 #else
 template <>
-static constexpr auto GetMfma<int8_t, 32, 32>()
+constexpr auto GetMfma<int8_t, 32, 32>()
 {
 return MfmaInstr::mfma_i32_32x32x8i8;
 }
 template <>
-static constexpr auto GetMfma<int8_t, 16, 16>()
+constexpr auto GetMfma<int8_t, 16, 16>()
 {
 return MfmaInstr::mfma_i32_16x16x16i8;
 }
 #endif
 template <>
-static constexpr auto GetMfma<f8_t, 32, 32>()
+constexpr auto GetMfma<f8_t, 32, 32>()
 {
 return MfmaInstr::mfma_f32_32x32x16f8f8;
 }
 template <>
-static constexpr auto GetMfma<f8_t, 16, 16>()
+constexpr auto GetMfma<f8_t, 16, 16>()
 {
 return MfmaInstr::mfma_f32_16x16x32f8f8;
 }
 template <>
-static constexpr auto GetMfma<bf8_t, 32, 32>()
+constexpr auto GetMfma<bf8_t, 32, 32>()
 {
 return MfmaInstr::mfma_f32_32x32x16bf8bf8;
 }
 template <>
-static constexpr auto GetMfma<bf8_t, 16, 16>()
+constexpr auto GetMfma<bf8_t, 16, 16>()
 {
 return MfmaInstr::mfma_f32_16x16x32bf8bf8;
 }
 template <>
-static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
+constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
 {
 return MfmaInstr::mfma_f32_32x32x16f8bf8;
 }
 template <>
-static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
+constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
 {
 return MfmaInstr::mfma_f32_16x16x32f8bf8;
 }
 template <>
-static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
+constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
 {
 return MfmaInstr::mfma_f32_32x32x16bf8f8;
 }
 template <>
-static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
+constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
 {
 return MfmaInstr::mfma_f32_16x16x32bf8f8;
 }
......
@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
 return (xx & 0x7FFF) > 0x7C00;
 };
+static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 static inline __host__ bool isnan(int4_t x)
 {
@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
 return (xx & 0x7FFF) > 0x7C00;
 };
+static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
 static inline __device__ half_t sqrt(half_t x)
 {
 return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
......
@@ -157,8 +157,11 @@
 #endif
 #endif
+// workaround for ROCm 6.2 and later
 #ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
-#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
+#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
+    (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
+    (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
 #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
 #else
 #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
......