"test/vscode:/vscode.git/clone" did not exist on "4ef704d8a6741ec563439053b7cae17240c73438"
Commit 63b152d6 authored by danyao12's avatar danyao12
Browse files

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents ae2d7d2b 14c3cfb1
...@@ -5,3 +5,4 @@ include_directories(AFTER ...@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory(01_fmha) add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d) add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm) add_subdirectory(03_gemm)
add_subdirectory(04_img2col)
...@@ -97,13 +97,6 @@ ...@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
......
...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<1>() __device__ constexpr auto TailScheduler<1>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 ...@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
} }
template <> template <>
__device__ static constexpr auto TailScheduler<2>() __device__ constexpr auto TailScheduler<2>()
{ {
// schedule // schedule
constexpr auto num_ds_read_inst = constexpr auto num_ds_read_inst =
......
...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "device_base.hpp" #include "device_base.hpp"
...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator ...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0; ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::size_t GetWorkspaceSize(index_t MRaw, virtual std::size_t GetWorkspaceSize(index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC) = 0; index_t StrideC) const = 0;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t K, [[maybe_unused]] index_t K,
[[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB, [[maybe_unused]] index_t StrideB,
index_t StrideC) override index_t StrideC) const override
{ {
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
} }
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
{
const auto* parg = dynamic_cast<const Argument*>(base_arg);
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return GetWorkspaceSize(
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
}
}; };
} // namespace device } // namespace device
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -419,6 +419,12 @@ struct UnaryAbs ...@@ -419,6 +419,12 @@ struct UnaryAbs
y = ck::math::abs(x); y = ck::math::abs(x);
}; };
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
};
}; };
struct UnarySqrt struct UnarySqrt
......
...@@ -324,55 +324,55 @@ struct DppSelector ...@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp(); static constexpr auto GetDpp();
template <> template <>
static constexpr auto GetDpp<half_t, 8, 32>() constexpr auto GetDpp<half_t, 8, 32>()
{ {
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 8, 16>() constexpr auto GetDpp<half_t, 8, 16>()
{ {
return DppInstr::dpp8_f16_8x16x2; return DppInstr::dpp8_f16_8x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() constexpr auto GetDpp<half_t, 16, 16>()
{ {
return DppInstr::dpp8_f16_16x16x2; return DppInstr::dpp8_f16_16x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 32, 8>() constexpr auto GetDpp<half_t, 32, 8>()
{ {
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 1, 32>() constexpr auto GetDpp<half_t, 1, 32>()
{ {
return DppInstr::dpp8_f16_1x32x2; return DppInstr::dpp8_f16_1x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 32>() constexpr auto GetDpp<half_t, 2, 32>()
{ {
return DppInstr::dpp8_f16_2x32x2; return DppInstr::dpp8_f16_2x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 16>() constexpr auto GetDpp<half_t, 2, 16>()
{ {
return DppInstr::dpp8_f16_2x16x2; return DppInstr::dpp8_f16_2x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 16>() constexpr auto GetDpp<half_t, 4, 16>()
{ {
return DppInstr::dpp8_f16_4x16x2; return DppInstr::dpp8_f16_4x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 32>() constexpr auto GetDpp<half_t, 4, 32>()
{ {
return DppInstr::dpp8_f16_4x32x2; return DppInstr::dpp8_f16_4x32x2;
} }
......
...@@ -415,7 +415,7 @@ struct WmmaSelector ...@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma(); static constexpr auto GetWmma();
template <> template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>() constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
...@@ -425,7 +425,7 @@ struct WmmaSelector ...@@ -425,7 +425,7 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
...@@ -435,19 +435,19 @@ struct WmmaSelector ...@@ -435,19 +435,19 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>() constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{ {
return WmmaInstr::wmma_f16_16x16x16_f16; return WmmaInstr::wmma_f16_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{ {
return WmmaInstr::wmma_bf16_16x16x16_bf16; return WmmaInstr::wmma_bf16_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>() constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
...@@ -458,7 +458,7 @@ struct WmmaSelector ...@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>() constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu4; return WmmaInstr::wmma_i32_16x16x16_iu4;
} }
......
...@@ -651,97 +651,97 @@ struct MfmaSelector ...@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
static constexpr auto GetMfma<double, 16, 16>() constexpr auto GetMfma<double, 16, 16>()
{ {
return MfmaInstr::mfma_f64_16x16x4f64; return MfmaInstr::mfma_f64_16x16x4f64;
} }
template <> template <>
static constexpr auto GetMfma<float, 64, 64>() constexpr auto GetMfma<float, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 64>() constexpr auto GetMfma<float, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 64>() constexpr auto GetMfma<float, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x1xf32; return MfmaInstr::mfma_f32_16x16x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 8, 64>() constexpr auto GetMfma<float, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 4, 64>() constexpr auto GetMfma<float, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 32>() constexpr auto GetMfma<float, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x2xf32; return MfmaInstr::mfma_f32_32x32x2xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 16>() constexpr auto GetMfma<float, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x4xf32; return MfmaInstr::mfma_f32_16x16x4xf32;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 64, 64>() constexpr auto GetMfma<half_t, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 64>() constexpr auto GetMfma<half_t, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 32>() constexpr auto GetMfma<half_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 16>() constexpr auto GetMfma<half_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x16f16; return MfmaInstr::mfma_f32_16x16x16f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 64>() constexpr auto GetMfma<half_t, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x4f16; return MfmaInstr::mfma_f32_16x16x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 8, 64>() constexpr auto GetMfma<half_t, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 4, 64>() constexpr auto GetMfma<half_t, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>() constexpr auto GetMfma<bhalf_t, 32, 32>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k; return MfmaInstr::mfma_f32_32x32x8bf16_1k;
...@@ -751,7 +751,7 @@ struct MfmaSelector ...@@ -751,7 +751,7 @@ struct MfmaSelector
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>() constexpr auto GetMfma<bhalf_t, 16, 16>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k; return MfmaInstr::mfma_f32_16x16x16bf16_1k;
...@@ -762,72 +762,72 @@ struct MfmaSelector ...@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940) #if defined(CK_USE_AMD_MFMA_GFX940)
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x16i8; return MfmaInstr::mfma_i32_32x32x16i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x32i8; return MfmaInstr::mfma_i32_16x16x32i8;
} }
#else #else
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x16i8; return MfmaInstr::mfma_i32_16x16x16i8;
} }
#endif #endif
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() constexpr auto GetMfma<f8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8f8; return MfmaInstr::mfma_f32_32x32x16f8f8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16>() constexpr auto GetMfma<f8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8bf8; return MfmaInstr::mfma_f32_32x32x16bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16>() constexpr auto GetMfma<bf8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8bf8; return MfmaInstr::mfma_f32_32x32x16f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>() constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8f8; return MfmaInstr::mfma_f32_32x32x16bf8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>() constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,8 +13,24 @@ using int4_t = _BitInt(4); ...@@ -13,8 +13,24 @@ using int4_t = _BitInt(4);
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N, typename Enable = void>
struct vector_type; struct vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
...@@ -171,7 +187,7 @@ struct scalar_type<bool> ...@@ -171,7 +187,7 @@ struct scalar_type<bool>
}; };
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using type = d1_t;
...@@ -189,7 +205,8 @@ struct vector_type<T, 1> ...@@ -189,7 +205,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -197,7 +214,8 @@ struct vector_type<T, 1> ...@@ -197,7 +214,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -205,7 +223,7 @@ struct vector_type<T, 1> ...@@ -205,7 +223,7 @@ struct vector_type<T, 1>
__device__ int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -226,7 +244,8 @@ struct vector_type<T, 2> ...@@ -226,7 +244,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -245,7 +264,8 @@ struct vector_type<T, 2> ...@@ -245,7 +264,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -263,7 +283,7 @@ struct vector_type<T, 2> ...@@ -263,7 +283,7 @@ struct vector_type<T, 2>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -287,7 +307,7 @@ struct vector_type<T, 4> ...@@ -287,7 +307,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -311,7 +331,7 @@ struct vector_type<T, 4> ...@@ -311,7 +331,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -333,7 +353,7 @@ struct vector_type<T, 4> ...@@ -333,7 +353,7 @@ struct vector_type<T, 4>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -360,7 +380,7 @@ struct vector_type<T, 8> ...@@ -360,7 +380,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -389,7 +409,7 @@ struct vector_type<T, 8> ...@@ -389,7 +409,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -415,7 +435,7 @@ struct vector_type<T, 8> ...@@ -415,7 +435,7 @@ struct vector_type<T, 8>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -445,7 +465,7 @@ struct vector_type<T, 16> ...@@ -445,7 +465,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -479,7 +499,7 @@ struct vector_type<T, 16> ...@@ -479,7 +499,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -509,7 +529,7 @@ struct vector_type<T, 16> ...@@ -509,7 +529,7 @@ struct vector_type<T, 16>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -541,7 +561,7 @@ struct vector_type<T, 32> ...@@ -541,7 +561,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -579,7 +599,7 @@ struct vector_type<T, 32> ...@@ -579,7 +599,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -613,7 +633,7 @@ struct vector_type<T, 32> ...@@ -613,7 +633,7 @@ struct vector_type<T, 32>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -648,7 +668,7 @@ struct vector_type<T, 64> ...@@ -648,7 +668,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -691,7 +711,7 @@ struct vector_type<T, 64> ...@@ -691,7 +711,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -729,7 +749,7 @@ struct vector_type<T, 64> ...@@ -729,7 +749,7 @@ struct vector_type<T, 64>
}; };
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -766,7 +786,7 @@ struct vector_type<T, 128> ...@@ -766,7 +786,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -813,7 +833,7 @@ struct vector_type<T, 128> ...@@ -813,7 +833,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -855,7 +875,7 @@ struct vector_type<T, 128> ...@@ -855,7 +875,7 @@ struct vector_type<T, 128>
}; };
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -894,7 +914,7 @@ struct vector_type<T, 256> ...@@ -894,7 +914,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -945,7 +965,7 @@ struct vector_type<T, 256> ...@@ -945,7 +965,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -990,6 +1010,581 @@ struct vector_type<T, 256> ...@@ -990,6 +1010,581 @@ struct vector_type<T, 256>
} }
}; };
template <typename T, index_t N>
struct non_native_vector_base
{
using type = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
union alignas(next_pow2(2 * sizeof(T)))
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using type = d4_t;
union alignas(next_pow2(4 * sizeof(T)))
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t;
union alignas(next_pow2(8 * sizeof(T)))
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using type = d16_t;
union alignas(next_pow2(16 * sizeof(T)))
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using type = d32_t;
union alignas(next_pow2(32 * sizeof(T)))
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using d64_t = non_native_vector_base<T, 64>;
using type = d64_t;
union alignas(next_pow2(64 * sizeof(T)))
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
using int64_t = long; using int64_t = long;
// fp64 // fp64
...@@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type; ...@@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8 // u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type; using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type; using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type; using uint8x8_t = typename vector_type<uint8_t, 8>::type;
......
...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x) ...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
{ {
...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x) ...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
......
...@@ -157,8 +157,11 @@ ...@@ -157,8 +157,11 @@
#endif #endif
#endif #endif
// workaround for ROCm 6.2 and later
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE #ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133 #if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1 #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else #else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0 #define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment