Unverified Commit dda18da0 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into ck_migraphx_integration

parents 3b2a7aee 4cf70b36
...@@ -41,18 +41,39 @@ template <typename LayoutA, ...@@ -41,18 +41,39 @@ template <typename LayoutA,
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kTilePermute = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>; // The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadA,
kPadB,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
auto kargs = Kernel::MakeKargs(args.p_a, auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b, args.p_b,
...@@ -255,15 +276,13 @@ int main(int argc, char* argv[]) ...@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
ck_tile::sequence<M_Warp, N_Warp, K_Warp>, ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>; ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType, using CodegenGemmTraits = ck_tile::
BDataType, TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
AccDataType,
CodegenGemmShape, using CodegenPipelineProblem = ck_tile::
kPadA, GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
kPadB,
kPadC>;
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t, invoke_gemm<ck_tile::half_t,
matrix_a_layout, matrix_a_layout,
...@@ -341,7 +360,13 @@ int main(int argc, char* argv[]) ...@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions); ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c); a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
c_buf.FromDevice(c_host_gpu_ref.data()); c_buf.FromDevice(c_host_gpu_ref.data());
......
...@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming ...@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
``` ```
# in the root of ck_tile # in the root of ck_tile
mkdir build && cd build mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942... # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j make tile_example_img2col -j
``` ```
This will result in an executable `build/bin/tile_example_img2col` This will result in an executable `build/bin/tile_example_img2col`
...@@ -97,13 +97,6 @@ ...@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ #cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif #endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
// //
// CK kernels which support XDL (MI series) // CK kernels which support XDL (MI series)
// //
......
...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config, ...@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipEventElapsedTime(&total_time, start, stop)); hip_check_error(hipEventElapsedTime(&total_time, start, stop));
hip_check_error(hipEventDestroy(start));
hip_check_error(hipEventDestroy(stop));
return total_time / nrepeat; return total_time / nrepeat;
} }
else else
......
...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run( xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(), b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); c_thread_buf_per_scale.GetVectorTypeReference(I0));
...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr ...@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(), xdlops_gemm.template Run<>(
b_thread_vec.template AsType<mfma_input_type>(), a_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0)); b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
}); });
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset = constexpr index_t c_offset =
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "device_base.hpp" #include "device_base.hpp"
...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator ...@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0; ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual std::size_t GetWorkspaceSize(index_t MRaw, virtual std::size_t GetWorkspaceSize(index_t MRaw,
index_t NRaw, index_t NRaw,
index_t KRaw, index_t KRaw,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC) = 0; index_t StrideC) const = 0;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle ...@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
[[maybe_unused]] index_t K, [[maybe_unused]] index_t K,
[[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideA,
[[maybe_unused]] index_t StrideB, [[maybe_unused]] index_t StrideB,
index_t StrideC) override index_t StrideC) const override
{ {
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
} }
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
{
const auto* parg = dynamic_cast<const Argument*>(base_arg);
if(!parg)
{
std::ostringstream err;
err << "Provided argument pointer is not of an Argument class!"
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
return GetWorkspaceSize(
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
}
}; };
} // namespace device } // namespace device
......
...@@ -64,7 +64,7 @@ __global__ void ...@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
......
...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage ...@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_; const index_t K = gemm_descs[i].K_;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
skipped_group_count_++; skipped_group_count_++;
continue; continue;
......
...@@ -109,7 +109,7 @@ __global__ void ...@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N; N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K; K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
{ {
grid_size_grp = 0; grid_size_grp = 0;
continue; continue;
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0) if(M == 0 || N == 0 || K == 0)
return; return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA; const auto StrideA = gemm_desc_ptr[group_id].StrideA;
......
...@@ -419,6 +419,12 @@ struct UnaryAbs ...@@ -419,6 +419,12 @@ struct UnaryAbs
y = ck::math::abs(x); y = ck::math::abs(x);
}; };
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
};
}; };
struct UnarySqrt struct UnarySqrt
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -30,8 +30,24 @@ using int4_t = _BitInt(4); ...@@ -30,8 +30,24 @@ using int4_t = _BitInt(4);
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
inline constexpr auto next_pow2(uint32_t x)
{
// Precondition: x > 1.
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
}
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N, typename Enable = void>
struct vector_type; struct vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
...@@ -188,7 +204,7 @@ struct scalar_type<bool> ...@@ -188,7 +204,7 @@ struct scalar_type<bool>
}; };
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using type = d1_t;
...@@ -206,7 +222,8 @@ struct vector_type<T, 1> ...@@ -206,7 +222,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -214,7 +231,8 @@ struct vector_type<T, 1> ...@@ -214,7 +231,8 @@ struct vector_type<T, 1>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_; return data_.d1x1_;
} }
...@@ -222,7 +240,7 @@ struct vector_type<T, 1> ...@@ -222,7 +240,7 @@ struct vector_type<T, 1>
__device__ int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2> struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -243,7 +261,8 @@ struct vector_type<T, 2> ...@@ -243,7 +261,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -262,7 +281,8 @@ struct vector_type<T, 2> ...@@ -262,7 +281,8 @@ struct vector_type<T, 2>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!"); static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -280,7 +300,7 @@ struct vector_type<T, 2> ...@@ -280,7 +300,7 @@ struct vector_type<T, 2>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4> struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -304,7 +324,7 @@ struct vector_type<T, 4> ...@@ -304,7 +324,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -328,7 +348,7 @@ struct vector_type<T, 4> ...@@ -328,7 +348,7 @@ struct vector_type<T, 4>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -350,7 +370,7 @@ struct vector_type<T, 4> ...@@ -350,7 +370,7 @@ struct vector_type<T, 4>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8> struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -377,7 +397,7 @@ struct vector_type<T, 8> ...@@ -377,7 +397,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -406,7 +426,7 @@ struct vector_type<T, 8> ...@@ -406,7 +426,7 @@ struct vector_type<T, 8>
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -432,7 +452,7 @@ struct vector_type<T, 8> ...@@ -432,7 +452,7 @@ struct vector_type<T, 8>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16> struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -462,7 +482,7 @@ struct vector_type<T, 16> ...@@ -462,7 +482,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -496,7 +516,7 @@ struct vector_type<T, 16> ...@@ -496,7 +516,7 @@ struct vector_type<T, 16>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value, is_same<X, d16_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -526,7 +546,7 @@ struct vector_type<T, 16> ...@@ -526,7 +546,7 @@ struct vector_type<T, 16>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32> struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -558,7 +578,7 @@ struct vector_type<T, 32> ...@@ -558,7 +578,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -596,7 +616,7 @@ struct vector_type<T, 32> ...@@ -596,7 +616,7 @@ struct vector_type<T, 32>
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value, is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -630,7 +650,7 @@ struct vector_type<T, 32> ...@@ -630,7 +650,7 @@ struct vector_type<T, 32>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64> struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -665,7 +685,7 @@ struct vector_type<T, 64> ...@@ -665,7 +685,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -708,7 +728,7 @@ struct vector_type<T, 64> ...@@ -708,7 +728,7 @@ struct vector_type<T, 64>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value, is_same<X, d64_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -746,7 +766,7 @@ struct vector_type<T, 64> ...@@ -746,7 +766,7 @@ struct vector_type<T, 64>
}; };
template <typename T> template <typename T>
struct vector_type<T, 128> struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -783,7 +803,7 @@ struct vector_type<T, 128> ...@@ -783,7 +803,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -830,7 +850,7 @@ struct vector_type<T, 128> ...@@ -830,7 +850,7 @@ struct vector_type<T, 128>
is_same<X, d4_t>::value || is_same<X, d8_t>::value || is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -872,7 +892,7 @@ struct vector_type<T, 128> ...@@ -872,7 +892,7 @@ struct vector_type<T, 128>
}; };
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -911,7 +931,7 @@ struct vector_type<T, 256> ...@@ -911,7 +931,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -962,7 +982,7 @@ struct vector_type<T, 256> ...@@ -962,7 +982,7 @@ struct vector_type<T, 256>
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value || is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value || is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value, is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"wrong!"); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value)
{ {
...@@ -1007,6 +1027,581 @@ struct vector_type<T, 256> ...@@ -1007,6 +1027,581 @@ struct vector_type<T, 256>
} }
}; };
template <typename T, index_t N>
struct non_native_vector_base
{
using type = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
union alignas(next_pow2(2 * sizeof(T)))
{
d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x2_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using type = d4_t;
union alignas(next_pow2(4 * sizeof(T)))
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t;
union alignas(next_pow2(8 * sizeof(T)))
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using type = d16_t;
union alignas(next_pow2(16 * sizeof(T)))
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using type = d32_t;
union alignas(next_pow2(32 * sizeof(T)))
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d32_t = non_native_vector_base<T, 32>;
using d64_t = non_native_vector_base<T, 64>;
using type = d64_t;
union alignas(next_pow2(64 * sizeof(T)))
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
using int64_t = long; using int64_t = long;
// fp64 // fp64
...@@ -1068,8 +1663,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type; ...@@ -1068,8 +1663,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
// u8 // u8
// i8
using uint8x2_t = typename vector_type<uint8_t, 2>::type; using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type; using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type; using uint8x8_t = typename vector_type<uint8_t, 8>::type;
......
...@@ -81,6 +81,8 @@ static inline __host__ bool isnan(half_t x) ...@@ -81,6 +81,8 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x)
{ {
...@@ -531,6 +533,8 @@ static inline __device__ bool isnan(half_t x) ...@@ -531,6 +533,8 @@ static inline __device__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x)
{ {
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
......
...@@ -58,7 +58,7 @@ struct thread_buffer { ...@@ -58,7 +58,7 @@ struct thread_buffer {
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); } template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
template <typename X_, template <typename X_,
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false> typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto _get_as() const CK_TILE_HOST_DEVICE constexpr auto _get_as() const
......
...@@ -50,12 +50,22 @@ class ArgParser ...@@ -50,12 +50,22 @@ class ArgParser
} }
return *this; return *this;
} }
void print() void print() const
{ {
// find max key length
std::string::size_type max_key_length = 11;
for(auto& key : keys)
{
if(max_key_length < key.length())
{
max_key_length = key.length();
}
}
printf("args:\n"); printf("args:\n");
for(auto& key : keys) for(auto& key : keys)
{ {
auto value = input_map[key]; auto value = input_map.at(key);
std::vector<std::string> help_text_lines; std::vector<std::string> help_text_lines;
size_t pos = 0; size_t pos = 0;
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;) for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
...@@ -69,8 +79,7 @@ class ArgParser ...@@ -69,8 +79,7 @@ class ArgParser
std::string(value.help_text.begin() + pos, value.help_text.end())); std::string(value.help_text.begin() + pos, value.help_text.end()));
std::string default_value = std::string("(default:") + value.value + std::string(")"); std::string default_value = std::string("(default:") + value.value + std::string(")");
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value << std::setw(4) << " " << help_text_lines[0] << " " << default_value
<< std::endl; << std::endl;
...@@ -78,7 +87,8 @@ class ArgParser ...@@ -78,7 +87,8 @@ class ArgParser
help_next_line != help_text_lines.end(); help_next_line != help_text_lines.end();
++help_next_line) ++help_next_line)
{ {
std::cout << std::setw(17) << " " << *help_next_line << std::endl; std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
<< std::endl;
} }
} }
} }
......
...@@ -13,7 +13,6 @@ namespace conv { ...@@ -13,7 +13,6 @@ namespace conv {
struct ConvParam struct ConvParam
{ {
ConvParam();
ConvParam(ck_tile::index_t n_dim, ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count, ck_tile::index_t group_count,
ck_tile::index_t n_batch, ck_tile::index_t n_batch,
...@@ -199,11 +198,6 @@ struct ConvParam ...@@ -199,11 +198,6 @@ struct ConvParam
} }
}; };
ConvParam::ConvParam()
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
{
}
CK_TILE_HOST std::string get_conv_param_parser_helper_msg() CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{ {
std::string msg; std::string msg;
......
...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const BElementOp& b_element_op = {}, const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {}) const ACCElementOp& acc_element_op = {})
{ {
const int N = b_n_k.mDesc.get_lengths()[0]; const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_n_k.mDesc.get_lengths()[0]
: b_n_k.mDesc.get_lengths()[1];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1] ? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0]; : a_m_k.mDesc.get_lengths()[0];
...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k, ...@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>) ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k)) ? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m)); : a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k)); BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? b_element_op(b_n_k(n, k))
: b_element_op(b_n_k(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) * v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b); ck_tile::type_convert<AccDataType>(v_b);
} }
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc)); CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? c_m_n(m, n)
: c_m_n(n, m);
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
} }
}; };
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, __global__ void naive_gemm_kernel(ADataType* A,
BDataType* B, BDataType* B,
CDataType* C, CDataType* C,
...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A, ...@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
if(row < M && col < N) if(row < M && col < N)
{ {
AccDataType acc = 0.0; AccDataType acc = 0.0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
acc += static_cast<AccDataType>(A[row * strideA + k]) * // Adjust indexing based on matrix layout
static_cast<AccDataType>(B[col * strideB + k]); int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
} }
C[row * strideC + col] = acc; // Store as AccDataType int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = acc;
} }
} }
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType> template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device, void reference_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device, DeviceMem& b_device,
DeviceMem& c_device, DeviceMem& c_device,
...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
errC = hipMemcpy( errC = hipMemcpy(
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost); c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment