Commit 3dc5db72 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents b924e330 e547c141
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "device_base.hpp" #include "device_base.hpp"
...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator ...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0; ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::size_t GetWorkspaceSize(index_t MRaw, virtual std::size_t GetWorkspaceSize(index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC) = 0; index_t StrideC) const = 0;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t K, [[maybe_unused]] index_t K,
[[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB, [[maybe_unused]] index_t StrideB,
index_t StrideC) override index_t StrideC) const override
{ {
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
} }
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
{
const auto* parg = dynamic_cast<const Argument*>(base_arg);
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return GetWorkspaceSize(
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
}
}; };
} // namespace device } // namespace device
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -419,6 +419,12 @@ struct UnaryAbs ...@@ -419,6 +419,12 @@ struct UnaryAbs
y = ck::math::abs(x); y = ck::math::abs(x);
}; };
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
};
}; };
struct UnarySqrt struct UnarySqrt
......
...@@ -324,55 +324,55 @@ struct DppSelector ...@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp(); static constexpr auto GetDpp();
template <> template <>
static constexpr auto GetDpp<half_t, 8, 32>() constexpr auto GetDpp<half_t, 8, 32>()
{ {
return DppInstr::dpp8_f16_8x32x2; return DppInstr::dpp8_f16_8x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 8, 16>() constexpr auto GetDpp<half_t, 8, 16>()
{ {
return DppInstr::dpp8_f16_8x16x2; return DppInstr::dpp8_f16_8x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 16, 16>() constexpr auto GetDpp<half_t, 16, 16>()
{ {
return DppInstr::dpp8_f16_16x16x2; return DppInstr::dpp8_f16_16x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 32, 8>() constexpr auto GetDpp<half_t, 32, 8>()
{ {
return DppInstr::dpp8_f16_32x8x2; return DppInstr::dpp8_f16_32x8x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 1, 32>() constexpr auto GetDpp<half_t, 1, 32>()
{ {
return DppInstr::dpp8_f16_1x32x2; return DppInstr::dpp8_f16_1x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 32>() constexpr auto GetDpp<half_t, 2, 32>()
{ {
return DppInstr::dpp8_f16_2x32x2; return DppInstr::dpp8_f16_2x32x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 2, 16>() constexpr auto GetDpp<half_t, 2, 16>()
{ {
return DppInstr::dpp8_f16_2x16x2; return DppInstr::dpp8_f16_2x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 16>() constexpr auto GetDpp<half_t, 4, 16>()
{ {
return DppInstr::dpp8_f16_4x16x2; return DppInstr::dpp8_f16_4x16x2;
} }
template <> template <>
static constexpr auto GetDpp<half_t, 4, 32>() constexpr auto GetDpp<half_t, 4, 32>()
{ {
return DppInstr::dpp8_f16_4x32x2; return DppInstr::dpp8_f16_4x32x2;
} }
......
...@@ -415,7 +415,7 @@ struct WmmaSelector ...@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma(); static constexpr auto GetWmma();
template <> template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>() constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
...@@ -425,7 +425,7 @@ struct WmmaSelector ...@@ -425,7 +425,7 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
...@@ -435,19 +435,19 @@ struct WmmaSelector ...@@ -435,19 +435,19 @@ struct WmmaSelector
} }
template <> template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>() constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{ {
return WmmaInstr::wmma_f16_16x16x16_f16; return WmmaInstr::wmma_f16_16x16x16_f16;
} }
template <> template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>() constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{ {
return WmmaInstr::wmma_bf16_16x16x16_bf16; return WmmaInstr::wmma_bf16_16x16x16_bf16;
} }
template <> template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>() constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{ {
#ifdef __gfx12__ #ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
...@@ -458,7 +458,7 @@ struct WmmaSelector ...@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <> template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>() constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{ {
return WmmaInstr::wmma_i32_16x16x16_iu4; return WmmaInstr::wmma_i32_16x16x16_iu4;
} }
......
...@@ -651,97 +651,97 @@ struct MfmaSelector ...@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma(); static constexpr auto GetMfma();
template <> template <>
static constexpr auto GetMfma<double, 16, 16>() constexpr auto GetMfma<double, 16, 16>()
{ {
return MfmaInstr::mfma_f64_16x16x4f64; return MfmaInstr::mfma_f64_16x16x4f64;
} }
template <> template <>
static constexpr auto GetMfma<float, 64, 64>() constexpr auto GetMfma<float, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 64>() constexpr auto GetMfma<float, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x1xf32; return MfmaInstr::mfma_f32_32x32x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 64>() constexpr auto GetMfma<float, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x1xf32; return MfmaInstr::mfma_f32_16x16x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 8, 64>() constexpr auto GetMfma<float, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 4, 64>() constexpr auto GetMfma<float, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x1xf32; return MfmaInstr::mfma_f32_4x4x1xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 32, 32>() constexpr auto GetMfma<float, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x2xf32; return MfmaInstr::mfma_f32_32x32x2xf32;
} }
template <> template <>
static constexpr auto GetMfma<float, 16, 16>() constexpr auto GetMfma<float, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x4xf32; return MfmaInstr::mfma_f32_16x16x4xf32;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 64, 64>() constexpr auto GetMfma<half_t, 64, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 64>() constexpr auto GetMfma<half_t, 32, 64>()
{ {
return MfmaInstr::mfma_f32_32x32x4f16; return MfmaInstr::mfma_f32_32x32x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 32, 32>() constexpr auto GetMfma<half_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x8f16; return MfmaInstr::mfma_f32_32x32x8f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 16>() constexpr auto GetMfma<half_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x16f16; return MfmaInstr::mfma_f32_16x16x16f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 16, 64>() constexpr auto GetMfma<half_t, 16, 64>()
{ {
return MfmaInstr::mfma_f32_16x16x4f16; return MfmaInstr::mfma_f32_16x16x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 8, 64>() constexpr auto GetMfma<half_t, 8, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<half_t, 4, 64>() constexpr auto GetMfma<half_t, 4, 64>()
{ {
return MfmaInstr::mfma_f32_4x4x4f16; return MfmaInstr::mfma_f32_4x4x4f16;
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>() constexpr auto GetMfma<bhalf_t, 32, 32>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k; return MfmaInstr::mfma_f32_32x32x8bf16_1k;
...@@ -751,7 +751,7 @@ struct MfmaSelector ...@@ -751,7 +751,7 @@ struct MfmaSelector
} }
template <> template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>() constexpr auto GetMfma<bhalf_t, 16, 16>()
{ {
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k; return MfmaInstr::mfma_f32_16x16x16bf16_1k;
...@@ -762,72 +762,72 @@ struct MfmaSelector ...@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940) #if defined(CK_USE_AMD_MFMA_GFX940)
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x16i8; return MfmaInstr::mfma_i32_32x32x16i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x32i8; return MfmaInstr::mfma_i32_16x16x32i8;
} }
#else #else
template <> template <>
static constexpr auto GetMfma<int8_t, 32, 32>() constexpr auto GetMfma<int8_t, 32, 32>()
{ {
return MfmaInstr::mfma_i32_32x32x8i8; return MfmaInstr::mfma_i32_32x32x8i8;
} }
template <> template <>
static constexpr auto GetMfma<int8_t, 16, 16>() constexpr auto GetMfma<int8_t, 16, 16>()
{ {
return MfmaInstr::mfma_i32_16x16x16i8; return MfmaInstr::mfma_i32_16x16x16i8;
} }
#endif #endif
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32>() constexpr auto GetMfma<f8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8f8; return MfmaInstr::mfma_f32_32x32x16f8f8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16>() constexpr auto GetMfma<f8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8f8; return MfmaInstr::mfma_f32_16x16x32f8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32>() constexpr auto GetMfma<bf8_t, 32, 32>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8bf8; return MfmaInstr::mfma_f32_32x32x16bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16>() constexpr auto GetMfma<bf8_t, 16, 16>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8bf8; return MfmaInstr::mfma_f32_16x16x32bf8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>() constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16f8bf8; return MfmaInstr::mfma_f32_32x32x16f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>() constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32f8bf8; return MfmaInstr::mfma_f32_16x16x32f8bf8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>() constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{ {
return MfmaInstr::mfma_f32_32x32x16bf8f8; return MfmaInstr::mfma_f32_32x32x16bf8f8;
} }
template <> template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>() constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{ {
return MfmaInstr::mfma_f32_16x16x32bf8f8; return MfmaInstr::mfma_f32_16x16x32bf8f8;
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -13,8 +13,24 @@ using int4_t = _BitInt(4); ...@@ -13,8 +13,24 @@ using int4_t = _BitInt(4);
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N, typename Enable = void>
struct vector_type; struct vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
...@@ -171,7 +187,7 @@ struct scalar_type<bool> ...@@ -171,7 +187,7 @@ struct scalar_type<bool>
}; };
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using type = d1_t;
...@@ -189,7 +205,8 @@ struct vector_type<T, 1> ...@@ -189,7 +205,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -197,7 +214,8 @@ struct vector_type<T, 1> ...@@ -197,7 +214,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -205,7 +223,7 @@ struct vector_type<T, 1> ...@@ -205,7 +223,7 @@ struct vector_type<T, 1>
__device__ int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -226,7 +244,8 @@ struct vector_type<T, 2> ...@@ -226,7 +244,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -245,7 +264,8 @@ struct vector_type<T, 2> ...@@ -245,7 +264,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -263,7 +283,7 @@ struct vector_type<T, 2> ...@@ -263,7 +283,7 @@ struct vector_type<T, 2>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -287,7 +307,7 @@ struct vector_type<T, 4> ...@@ -287,7 +307,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -311,7 +331,7 @@ struct vector_type<T, 4> ...@@ -311,7 +331,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -333,7 +353,7 @@ struct vector_type<T, 4> ...@@ -333,7 +353,7 @@ struct vector_type<T, 4>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -360,7 +380,7 @@ struct vector_type<T, 8> ...@@ -360,7 +380,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -389,7 +409,7 @@ struct vector_type<T, 8> ...@@ -389,7 +409,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -415,7 +435,7 @@ struct vector_type<T, 8> ...@@ -415,7 +435,7 @@ struct vector_type<T, 8>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -445,7 +465,7 @@ struct vector_type<T, 16> ...@@ -445,7 +465,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -479,7 +499,7 @@ struct vector_type<T, 16> ...@@ -479,7 +499,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -509,7 +529,7 @@ struct vector_type<T, 16> ...@@ -509,7 +529,7 @@ struct vector_type<T, 16>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -541,7 +561,7 @@ struct vector_type<T, 32> ...@@ -541,7 +561,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -579,7 +599,7 @@ struct vector_type<T, 32> ...@@ -579,7 +599,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -613,7 +633,7 @@ struct vector_type<T, 32> ...@@ -613,7 +633,7 @@ struct vector_type<T, 32>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -648,7 +668,7 @@ struct vector_type<T, 64> ...@@ -648,7 +668,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -691,7 +711,7 @@ struct vector_type<T, 64> ...@@ -691,7 +711,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -729,7 +749,7 @@ struct vector_type<T, 64> ...@@ -729,7 +749,7 @@ struct vector_type<T, 64>
}; };
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -766,7 +786,7 @@ struct vector_type<T, 128> ...@@ -766,7 +786,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -813,7 +833,7 @@ struct vector_type<T, 128> ...@@ -813,7 +833,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -855,7 +875,7 @@ struct vector_type<T, 128> ...@@ -855,7 +875,7 @@ struct vector_type<T, 128>
}; };
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -894,7 +914,7 @@ struct vector_type<T, 256> ...@@ -894,7 +914,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -945,7 +965,7 @@ struct vector_type<T, 256> ...@@ -945,7 +965,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -990,6 +1010,581 @@ struct vector_type<T, 256> ...@@ -990,6 +1010,581 @@ struct vector_type<T, 256>
} }
}; };
template <typename T, index_t N>
struct non_native_vector_base
{
using type = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
union alignas(next_pow2(2 * sizeof(T)))
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using type = d4_t;
union alignas(next_pow2(4 * sizeof(T)))
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t;
union alignas(next_pow2(8 * sizeof(T)))
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using type = d16_t;
union alignas(next_pow2(16 * sizeof(T)))
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using type = d32_t;
union alignas(next_pow2(32 * sizeof(T)))
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using d64_t = non_native_vector_base<T, 64>;
using type = d64_t;
union alignas(next_pow2(64 * sizeof(T)))
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
using int64_t = long; using int64_t = long;
// fp64 // fp64
...@@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type; ...@@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8 // u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type; using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type; using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type; using uint8x8_t = typename vector_type<uint8_t, 8>::type;
......
...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x) ...@@ -80,6 +80,8 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
{ {
...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x) ...@@ -529,6 +531,8 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp" #include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
...@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr ...@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return !(a == b); return !(a == b);
} }
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
{
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
template <typename T, index_t N, typename X> template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x) CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{ {
......
...@@ -58,7 +58,7 @@ struct thread_buffer { ...@@ -58,7 +58,7 @@ struct thread_buffer {
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_, template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false> typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const CK_TILE_HOST_DEVICE constexpr auto _get_as() const
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp" #include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp" #include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp" #include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp" #include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/hip_check_error.hpp"
......
...@@ -50,12 +50,22 @@ class ArgParser ...@@ -50,12 +50,22 @@ class ArgParser
} }
return *this; return *this;
} }
void print() void print() const
{ {
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n"); printf("args:\n");
for(auto& key : keys) for(auto& key : keys)
{ {
auto value = input_map[key]; auto value = input_map.at(key);
std::vector<std::string> help_text_lines; std::vector<std::string> help_text_lines;
size_t pos = 0; size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
...@@ -69,8 +79,7 @@ class ArgParser ...@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end())); std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")"); std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value << std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl; << std::endl;
...@@ -78,7 +87,8 @@ class ArgParser ...@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end(); help_next_line != help_text_lines.end();
++help_next_line) ++help_next_line)
{ {
std::cout << std::setw(17) << " " << *help_next_line << std::endl; std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
} }
} }
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
namespace conv {
namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
} // namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
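// Illustrative example (hypothetical 2D NHWGK output): the branch above builds
// physical_lengths = {N, Ho, Wo, G, K}; the returned descriptor is then indexed logically as
// (G, N, K, Ho, Wo) while its strides still describe the packed NHWGK memory layout.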
} // namespace conv
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
#include <functional>
#include <stdexcept>
#include <string>
namespace ck_tile {
namespace conv {
struct ConvParam
{
ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count,
ck_tile::index_t n_batch,
ck_tile::index_t n_out_channels,
ck_tile::index_t n_in_channels,
const std::vector<ck_tile::index_t>& filters_len,
const std::vector<ck_tile::index_t>& input_len,
const std::vector<ck_tile::index_t>& strides,
const std::vector<ck_tile::index_t>& dilations,
const std::vector<ck_tile::index_t>& left_pads,
const std::vector<ck_tile::index_t>& right_pads)
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
G_(static_cast<ck_tile::long_index_t>(group_count)),
N_(static_cast<ck_tile::long_index_t>(n_batch)),
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
filter_spatial_lengths_(num_dim_spatial_),
input_spatial_lengths_(num_dim_spatial_),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(num_dim_spatial_),
conv_filter_dilations_(num_dim_spatial_),
input_left_pads_(num_dim_spatial_),
input_right_pads_(num_dim_spatial_)
{
        // validate the caller-provided vectors before they are indexed below; the member
        // vectors were just constructed with size num_dim_spatial_, so checking them would
        // always succeed
        if(static_cast<ck_tile::index_t>(filters_len.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(input_len.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(strides.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(dilations.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(left_pads.size()) != num_dim_spatial_ ||
           static_cast<ck_tile::index_t>(right_pads.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
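            // e.g. (hypothetical numbers) X = 3, dilation = 2, stride = 2, Wi = 28,
            // left/right pad = 1: XEff = (3 - 1) * 2 + 1 = 5 and
            // Wo = (28 + 1 + 1 - 5) / 2 + 1 = 13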
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ConvParam(ck_tile::long_index_t n_dim,
ck_tile::long_index_t group_count,
ck_tile::long_index_t n_batch,
ck_tile::long_index_t n_out_channels,
ck_tile::long_index_t n_in_channels,
const std::vector<ck_tile::long_index_t>& filters_len,
const std::vector<ck_tile::long_index_t>& input_len,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads,
const std::vector<ck_tile::long_index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
K_(n_out_channels),
C_(n_in_channels),
filter_spatial_lengths_(filters_len),
input_spatial_lengths_(input_len),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(strides),
conv_filter_dilations_(dilations),
input_left_pads_(left_pads),
input_right_pads_(right_pads)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ck_tile::long_index_t num_dim_spatial_;
ck_tile::long_index_t G_;
ck_tile::long_index_t N_;
ck_tile::long_index_t K_;
ck_tile::long_index_t C_;
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
std::vector<ck_tile::long_index_t> conv_filter_strides_;
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
std::vector<ck_tile::long_index_t> input_left_pads_;
std::vector<ck_tile::long_index_t> input_right_pads_;
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
{
return output_spatial_lengths_;
}
std::size_t GetFlops() const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
        return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
               std::accumulate(std::begin(output_spatial_lengths_),
                               std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
                               static_cast<std::size_t>(1),
                               std::multiplies<>()) *
               std::accumulate(std::begin(filter_spatial_lengths_),
                               std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
                               static_cast<std::size_t>(1),
                               std::multiplies<>());
}
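    // Illustrative check (hypothetical sizes): the factor 2 counts one multiply and one add per
    // MAC, so G = 1, N = 8, K = 64, C = 32, a 28x28 output and a 3x3 filter give
    // 2 * 1 * 8 * 64 * 32 * (28 * 28) * (3 * 3) = 231,211,008 ~= 0.23 GFLOP.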
template <typename InDataType>
std::size_t GetInputByte() const
{
        // sizeof(InDataType) * (G * N * C * <input spatial lengths product>)
        return sizeof(InDataType) *
               (G_ * N_ * C_ *
                std::accumulate(std::begin(input_spatial_lengths_),
                                std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
                                static_cast<std::size_t>(1),
                                std::multiplies<>()));
}
template <typename WeiDataType>
std::size_t GetWeightByte() const
{
        // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>)
        return sizeof(WeiDataType) *
               (G_ * K_ * C_ *
                std::accumulate(std::begin(filter_spatial_lengths_),
                                std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
                                static_cast<std::size_t>(1),
                                std::multiplies<>()));
}
template <typename OutDataType>
std::size_t GetOutputByte() const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
GetOutputByte<OutDataType>();
}
};
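// Illustrative usage (hypothetical problem sizes, assuming long_index_t inputs):
//
//   std::vector<ck_tile::long_index_t> filt{3, 3}, in{28, 28}, str{1, 1}, dil{1, 1},
//       lpad{1, 1}, rpad{1, 1};
//   ck_tile::conv::ConvParam param(2, 1, 16, 64, 32, filt, in, str, dil, lpad, rpad);
//   // with these values param.GetOutputSpatialLengths() is {28, 28}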
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
std::string msg;
msg += "Following arguments (depending on number of spatial dims):\n"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
" G, N, K, C, \n"
" <filter spatial dimensions>, (ie Y, X for 2D)\n"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
" <strides>, (ie Sy, Sx for 2D)\n"
" <dilations>, (ie Dy, Dx for 2D)\n"
" <left padding>, (ie LeftPy, LeftPx for 2D)\n"
" <right padding>, (ie RightPy, RightPx for 2D)\n";
return msg;
}
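// Illustrative argument list (hypothetical values) matching the order described above for a
// 2D convolution:
//   2  1 16 64 32  3 3  28 28  1 1  1 1  1 1  1 1
// i.e. num_dim_spatial, then G N K C, filter (Y X), input (Hi Wi), strides, dilations,
// left pads and right pads.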
CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stol(argv[arg_idx++]);
}
return ck_tile::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
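// Minimal usage sketch (illustrative; assumes the command line follows
// get_conv_param_parser_helper_msg, with argv[1] holding the spatial-dim count so the
// remaining arguments start at index 2):
//
//   const int num_dim_spatial = std::stoi(argv[1]);
//   const ck_tile::conv::ConvParam param =
//       ck_tile::conv::parse_conv_param(num_dim_spatial, 2, argv);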
} // namespace conv
} // namespace ck_tile
...@@ -176,7 +176,20 @@ struct HostTensorDescriptor ...@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
} }
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
{
os << "dim " << desc.get_num_of_dimension() << ", ";
os << "lengths {";
LogRange(os, desc.get_lengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.get_strides(), ", ");
os << "}";
return os;
}
private: private:
std::vector<std::size_t> mLens; std::vector<std::size_t> mLens;
......
...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {}, const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {}) const ACCElementOp& acc_element_op = {})
{ {
const int N = b_n_k.mDesc.get_lengths()[0]; const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1] ? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0]; : a_m_k.mDesc.get_lengths()[0];
...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k)) ? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m)); : a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k)); BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) * v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b); ck_tile::type_convert<AccDataType>(v_b);
} }
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc)); CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
} }
}; };
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, __global__ void naive_gemm_kernel(ADataType* A,
BDataType* B, BDataType* B,
CDataType* C, CDataType* C,
...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, ...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N) if(row < M && col < N)
{ {
AccDataType acc = 0.0; AccDataType acc = 0.0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
acc += static_cast<AccDataType>(A[row * strideA + k]) * // Adjust indexing based on matrix layout
static_cast<AccDataType>(B[col * strideB + k]); int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
} }
C[row * strideC + col] = acc; // Store as AccDataType int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
} }
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device, void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device, DeviceMem& b_device,
DeviceMem& c_device, DeviceMem& c_device,
...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy( errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
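// Illustrative summary of the layout handling introduced above (assumed semantics, not taken
// verbatim from the patch): each operand's element is mapped to a linear offset through its
// leading-dimension stride, e.g. for A a row-major element (m, k) lives at m * strideA + k,
// while a column-major element lives at k * strideA + m; B and C are addressed the same way
// with strideB / strideC and their own layout tags.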