"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "ae86bfd9ed457514662c04c4f10f7aaf536d85ea"
Commit 33d1e0e2 authored by Chao Liu

refactoring for miopen

parent b1cb48a0
@@ -71,24 +71,7 @@ __device__ void threadwise_gemm(MatrixA,
                                  integral_constant<bool, TransC>,
                                  FloatC* __restrict__ p_c_thread)
 {
-#if 0
-    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
-    {
-        printf("p_a_thread: %f %f %f %f\n",
-               p_a_thread[0],
-               p_a_thread[1],
-               p_a_thread[2],
-               p_a_thread[3]);
-        printf("p_b_thread: %f %f %f %f\n",
-               p_b_thread[0],
-               p_b_thread[1],
-               p_b_thread[2],
-               p_b_thread[3]);
-    }
-#endif
-    if(TransA && (!TransB) && (!TransC))
-    {
+    static_if<TransA && (!TransB) && (!TransC)>{}([&](auto fwd) {
         constexpr auto a_mtx = MatrixA{};
         constexpr auto b_mtx = MatrixB{};
         constexpr auto c_mtx = MatrixC{};
@@ -111,12 +94,10 @@ __device__ void threadwise_gemm(MatrixA,
             }
         }
     }
-    }
-    else
-    {
+    }).Else([&](auto fwd) {
         // not implemented
-        assert(false);
-    }
+        static_assert(fwd(false), "wrong! support for this config is not implemented");
+    });
 }

 } // namespace ck
...
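Note on the static_if rewrite above: the else-branch now carries static_assert(fwd(false), ...) instead of a runtime assert(false). A minimal standalone emulation of the pattern (not ck's actual static_if implementation; assumes C++14 or later) shows why this only breaks the build when the unsupported branch is actually selected:

    #include <cstdio>

    template <bool>
    struct static_if;

    template <>
    struct static_if<true>
    {
        template <class F>
        static_if operator()(F f) const
        {
            f([](auto x) { return x; }); // pass an identity "fwd" so the body stays dependent
            return *this;
        }

        template <class G>
        void Else(G) const
        {
            // else-branch lambda is never called, so its body is never instantiated
        }
    };

    template <>
    struct static_if<false>
    {
        template <class F>
        static_if operator()(F) const { return *this; }

        template <class G>
        void Else(G g) const
        {
            g([](auto x) { return x; }); // instantiating the body fires its static_assert
        }
    };

    int main()
    {
        constexpr bool TransA = true, TransB = false, TransC = false;

        static_if<TransA && (!TransB) && (!TransC)>{}([&](auto) {
            std::printf("supported configuration\n");
        }).Else([&](auto fwd) {
            static_assert(fwd(false), "wrong! support for this config is not implemented");
        });
    }

Because the else branch is a generic lambda, its body is instantiated only when the false specialization calls it, so the static_assert fires exactly for template configurations that reach the unimplemented path.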
@@ -5,6 +5,10 @@
 #include "ConstantTensorDescriptor.hpp"
 #include "ConstantMergedTensorDescriptor.hpp"
+
+#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0
+#endif

 namespace ck {

 template <class Float,
@@ -32,21 +36,18 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     static_assert(is_valid_sequence_map<DimAccessOrder>::value, "wrong! map is not valid");

-#if 0
-    // doesn't compile, because merged-tensor reordering is not implemented
-    // TODO: implement tensor desc ops for merged-tensor
-    constexpr auto src_strides_in_access_order =
-        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    constexpr auto dst_strides_in_access_order =
-        SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
-    // check src/dst stride on the lowest access dimension
-    static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
-                      (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
-                  "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
-                  "vectorized read/write");
-#endif
+    // TODO: do more sanity-check here, something like:
+    // constexpr auto src_strides_in_access_order =
+    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+    // constexpr auto dst_strides_in_access_order =
+    //     SrcDesc::ReorderGivenNew2Old(DimAccessOrder{}).GetStride(Number<nDim-1>{});
+    // // check src/dst stride on the lowest access dimension
+    // static_assert((DataPerAccess == 1 || src_strides_in_access_order.Back() == 1) &&
+    //                   (DataPerAccess == 1 || dst_strides_in_access_order.Back() == 1),
+    //               "wrong! src/dst stride on the lowest access dimension needs to be 1 for "
+    //               "vectorized read/write");

     constexpr auto slice_lengths_in_access_order =
         SliceLengths::ReorderGivenNew2Old(DimAccessOrder{});
@@ -64,13 +65,15 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
     using vector_t = typename vector_type<Float, DataPerAccess>::MemoryType;

-#if 1
-    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        auto data_multi_id_in_access_order = access_multi_id;
-        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
+#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1
+    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
+        constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;

-        const auto data_multi_id =
-            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});
+        constexpr auto data_multi_id_in_access_order =
+            access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
+
+        constexpr auto data_multi_id = reorder_array_given_old2new(
+            sequence2array(data_multi_id_in_access_order), DimAccessOrder{});

         const index_t src_index =
             SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
@@ -82,14 +85,12 @@ __device__ void threadwise_generic_tensor_slice_copy_v1(
             *reinterpret_cast<const vector_t*>(&p_src[src_index]);
     });
 #else
-    static_ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
-        constexpr index_t itmp = access_multi_id.Back() * DataPerAccess;
-
-        constexpr auto data_multi_id_in_access_order =
-            access_multi_id.Modify(Number<nDim - 1>{}, Number<itmp>{});
-
-        constexpr auto data_multi_id = reorder_array_given_old2new(
-            sequence2array(data_multi_id_in_access_order), DimAccessOrder{});
+    ford<decltype(access_lengths)>{}([&](auto access_multi_id) {
+        auto data_multi_id_in_access_order = access_multi_id;
+        data_multi_id_in_access_order(nDim - 1) = access_multi_id[nDim - 1] * DataPerAccess;
+
+        const auto data_multi_id =
+            reorder_array_given_old2new(data_multi_id_in_access_order, DimAccessOrder{});

         const index_t src_index =
             SrcDesc::GetOffsetFromMultiIndex(src_multi_id_begin + data_multi_id);
...
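The two loop flavours above differ in when the copy's multi-index is known: the static_ford path guarded by CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 hands the body compile-time indices (so offsets can fold into constants), while the ford fallback uses ordinary runtime indices. A small standalone analogue of the distinction in one dimension (plain C++17, not ck's ford/static_ford API):

    #include <cstdio>
    #include <utility>

    // Compile-time loop: the body receives each index as a std::integral_constant,
    // so the index can be used in constant expressions.
    template <class F, int... Is>
    void static_for_impl(F f, std::integer_sequence<int, Is...>)
    {
        (f(std::integral_constant<int, Is>{}), ...);
    }

    template <int N, class F>
    void static_for(F f)
    {
        static_for_impl(f, std::make_integer_sequence<int, N>{});
    }

    int main()
    {
        // "static" flavour: each i is a distinct compile-time constant
        static_for<4>([](auto i) {
            constexpr int offset = i * 2; // folds at compile time
            std::printf("static  index %d -> offset %d\n", int(i), offset);
        });

        // "dynamic" flavour: i is an ordinary runtime variable
        for(int i = 0; i < 4; ++i)
        {
            std::printf("dynamic index %d -> offset %d\n", i, i * 2);
        }
    }

ck applies the same idea across all dimensions of access_lengths; the sketch keeps one dimension to stay short.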
@@ -56,7 +56,7 @@ __device__ void threadwise_tensor_slice_copy(SrcDesc,
     static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
         static_for<0, nRead, 1>{}([&](auto IRead) {
-            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});
+            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead * DataPerRead>{});

             const index_t src_index = src_desc.GetOffsetFromMultiIndex(multi_id);
@@ -177,8 +177,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
             // pack data
             static_for<0, DstDataPerWrite, 1>{}([&](auto IDstData) {
-                const auto dst_multi_id =
-                    ids.PushBack(IWrite.Get() * DstDataPerWrite + IDstData.Get());
+                const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite + IDstData);

                 const auto src_multi_id = reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
@@ -189,7 +188,7 @@ threadwise_tensor_slice_copy_reorder_given_dst2src_v3(SrcDesc,
             });

             // write data
-            const auto dst_multi_id = ids.PushBack(IWrite.Get() * DstDataPerWrite);
+            const auto dst_multi_id = ids.PushBack(IWrite * DstDataPerWrite);

             const index_t dst_index = dst_desc.GetOffsetFromMultiIndex(dst_multi_id);
...
@@ -98,7 +98,7 @@ __host__ __device__ constexpr auto reorder_array_given_new2old(const Array<TData
 {
     static_assert(NSize == sizeof...(IRs), "NSize not consistent");

-    static_assert(is_valid_sequence_map<Sequence<IRs...>>::value, "wrong! invalid reorder map");
+    static_assert(is_valid_sequence_map<Sequence<IRs...>>{}, "wrong! invalid reorder map");

     return Array<TData, NSize>{old_array[IRs]...};
 }
...
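For reference, the reorder being guarded above maps element i of the result to old_array[map[i]]. A standalone sketch of the same semantics over std::array (hypothetical helper, not ck's Array/Sequence types):

    #include <array>
    #include <cstdio>
    #include <utility>

    template <class T, std::size_t N, std::size_t... IRs>
    constexpr std::array<T, N> reorder_given_new2old(const std::array<T, N>& old_array,
                                                     std::index_sequence<IRs...>)
    {
        static_assert(N == sizeof...(IRs), "NSize not consistent");
        return {old_array[IRs]...};
    }

    int main()
    {
        constexpr std::array<int, 3> lengths{10, 20, 30};

        // new2old map {2, 0, 1}: result[0] = old[2], result[1] = old[0], result[2] = old[1]
        constexpr auto reordered =
            reorder_given_new2old(lengths, std::index_sequence<2, 0, 1>{});

        std::printf("%d %d %d\n", reordered[0], reordered[1], reordered[2]);
    }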
@@ -55,22 +55,6 @@ struct Sequence
         return Sequence<Type::Get(Number<IRs>{})...>{};
     }

-#if 0 // require sequence_sort, which is not implemented yet
-    template <class MapOld2New>
-    __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New /*old2new*/)
-    {
-        static_assert(sizeof...(Is) == MapOld2New::GetSize(),
-                      "wrong! reorder map should have the same size as Sequence to be rerodered");
-
-        static_assert(is_valid_sequence_map<MapOld2New>::value, "wrong! invalid reorder map");
-
-        constexpr auto map_new2old = typename sequence_map_inverse<MapOld2New>::SeqMapType{};
-
-        return ReorderGivenNew2Old(map_new2old);
-    }
-#endif
-
     __host__ __device__ static constexpr auto Reverse();

     __host__ __device__ static constexpr index_t Front()
@@ -263,74 +247,15 @@ struct sequence_reverse<Sequence<I0, I1>>
     using SeqType = Sequence<I1, I0>;
 };

-#if 0 // not fully implemented
-template <class KeySeq0, class ValSeq0, class KeySeq1, class ValSeq1>
-struct sequence_sort_merge_impl;
-
-template <index_t Key0,
-          index_t... Keys0,
-          index_t Val0,
-          index_t... Vals0,
-          index_t Key1,
-          index_t... Keys1,
-          index_t Val0,
-          index_t... Vals1>
-struct sequence_sort_merge_impl<Sequence<Key0, Keys0...>,
-                                Sequence<Val0, Vals0...>,
-                                Sequence<Key1, Keys1...>,
-                                Sequence<Val1, Vals1...>>
-{
-};
-
-template <class>
-struct sequence_sort;
-
-template <index_t... Is>
-struct sequence_sort<Sequence<Is...>>
-{
-    using OriginalSeqType = Sequence<Is...>;
-
-    using SortedSeqType          = xxxxx;
-    using MapSorted2OriginalType = xxx;
-};
-
-template <class Seq, class IsValidSeqMap>
-struct sequence_map_inverse_impl;
-
-// impl for valid map, no impl for invalid map
-template <index_t... Is>
-struct sequence_map_inverse_impl<Sequence<Is...>, true>
-{
-    using SeqMapType = sequence_sort<Sequence<Is...>>::MapSorted2OriginalType;
-};
-
-template <class>
-struct sequence_map_inverse;
-
-template <class Is...>
-struct sequence_map_inverse<Sequence<Is...>>
-{
-    // TODO: make sure the map to be inversed is valid: [0, sizeof...(Is))
-    static constexpr bool is_valid_sequence_map =
-        is_same<typename sequence_sort<Sequence<Is...>>::SortedSeqType,
-                typename arithmetic_sequence_gen<0, sizeof...(Is), 1>::SeqType>::value;
-
-    // make compiler fails, if is_valid_map != true
-    using SeqMapType =
-        typename sequence_map_inverse_impl<Sequence<Is...>, is_valid_map>::SeqMapType;
-};
-#endif
-
 template <class Seq>
 struct is_valid_sequence_map
 {
-    static constexpr bool value =
-#if 0 // sequence_sort is not implemented yet
-        is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
-                typename sequence_sort<Seq>::SortedSeqType>::value;
-#else
-        true;
-#endif
+    static constexpr bool value = true;
+
+    // TODO: add proper check for is_valid, something like:
+    // static constexpr bool value =
+    //     is_same<typename arithmetic_sequence_gen<0, Seq::GetSize(), 1>::SeqType,
+    //             typename sequence_sort<Seq>::SortedSeqType>{};
 };

 template <index_t... Xs, index_t... Ys>
...
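The TODO above still leaves is_valid_sequence_map as a stub. One way to get a real check without sequence_sort is a direct permutation test: a map is valid exactly when every index in [0, N) occurs once. A standalone constexpr sketch over a plain index pack (hypothetical helper, not the ck Sequence-based trait):

    #include <cstddef>

    template <std::size_t... Is>
    constexpr bool is_valid_map()
    {
        const std::size_t ids[] = {Is...};
        const std::size_t n     = sizeof...(Is);

        // every value in [0, n) must appear exactly once
        for(std::size_t v = 0; v < n; ++v)
        {
            std::size_t count = 0;

            for(std::size_t i = 0; i < n; ++i)
            {
                if(ids[i] == v)
                {
                    ++count;
                }
            }

            if(count != 1)
            {
                return false;
            }
        }

        return true;
    }

    static_assert(is_valid_map<2, 0, 1>(), "a permutation of 0..N-1 is a valid map");
    static_assert(!is_valid_map<0, 0, 2>(), "a repeated index is not a valid map");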
@@ -3,91 +3,8 @@
 #include "vector_type.hpp"

-#define NO_VM_WAIT 0
-#define NO_LGKM_WAIT 0
-#define NO_DS_READ 0
-#define NO_DS_WRITE 0
-#define NO_GLB_READ 0
-
 namespace ck {

-// cast a pointer of LDS to its address
-extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
-
-__device__ void vmcnt(index_t cnt)
-{
-#if !NO_VM_WAIT
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt vmcnt(2) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}
-
-__device__ void lgkmcnt(index_t cnt)
-{
-#if !NO_LGKM_WAIT
-    if(cnt == 0)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(0) \n \
-        " ::);
-    }
-    else if(cnt == 1)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(1) \n \
-        " ::);
-    }
-    else if(cnt == 2)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(2) \n \
-        " ::);
-    }
-    else if(cnt == 3)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(3) \n \
-        " ::);
-    }
-    else if(cnt == 4)
-    {
-        asm volatile("\n \
-        s_waitcnt lgkmcnt(4) \n \
-        " ::);
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}
-
 __device__ void outerProduct1x4(const float* a, const float* b, float* c)
 {
     asm volatile("\n \
@@ -112,21 +29,7 @@ __device__ void outerProduct1x4(const float& a,
                                 const vector_type<float, 4>::MemoryType& b,
                                 vector_type<float, 4>::MemoryType& c)
 {
-#if 0
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        : "v"(c.x), "v"(c.y), "v"(c.z), "v"(c.w), \
-          "v"(a.x), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w)
-    );
-#else
     outerProduct1x4(&a, (float*)&b, (float*)&c);
-#endif
 }

 __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
@@ -136,57 +39,10 @@ __device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
                                 vector_type<float, 4>::MemoryType& c2,
                                 vector_type<float, 4>::MemoryType& c3)
 {
-#if 0
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        : "v"(c0.x), "v"(c0.y), "v"(c0.z), "v"(c0.w), \
-          "v"(a.x), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w)
-    );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        : "v"(c1.x), "v"(c1.y), "v"(c1.z), "v"(c1.w), \
-          "v"(a.y), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w)
-    );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        : "v"(c2.x), "v"(c2.y), "v"(c2.z), "v"(c2.w), \
-          "v"(a.z), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w)
-    );
-
-    asm volatile(
-        "\n \
-        v_mac_f32 %0, %4, %5 \n \
-        v_mac_f32 %1, %4, %6 \n \
-        v_mac_f32 %2, %4, %7 \n \
-        v_mac_f32 %3, %4, %8 \n \
-        "
-        :
-        : "v"(c3.x), "v"(c3.y), "v"(c3.z), "v"(c3.w), \
-          "v"(a.w), "v"(b.x), "v"(b.y), "v"(b.z), "v"(b.w)
-    );
-#else
     outerProduct1x4(a.x, b, c0);
     outerProduct1x4(a.y, b, c1);
     outerProduct1x4(a.z, b, c2);
     outerProduct1x4(a.w, b, c3);
-#endif
 }

 __device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
@@ -201,7 +57,6 @@ __device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
 __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
 {
-#if !NO_DS_READ
     if(offset == 0)
     {
         asm volatile("\n \
@@ -722,33 +577,11 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
                      : "=v"(r)
                      : "v"(__to_local(lds)));
     }
-#endif
 }
-
-__device__ void global_load(vector_type<float, 4>::MemoryType& r,
-                            const vector_type<float, 4>::MemoryType* ptr,
-                            index_t offset = 0)
-{
-#if !NO_GLB_READ
-    if(offset == 0)
-    {
-        asm volatile("\n \
-        global_load_dwordx4 %0, %1, off \n \
-        "
-                     : "=v"(r)
-                     : "v"(ptr));
-    }
-    else
-    {
-        assert(false);
-    }
-#endif
-}

 __device__ void
 ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
 {
-#if !NO_DS_WRITE
     if(offset == 0)
     {
         asm volatile("\n \
@@ -761,7 +594,6 @@ ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t off
     {
         assert(false);
     }
-#endif
 }

 } // namespace ck
...
@@ -7,6 +7,9 @@
 #include "hip/hip_fp16.h"

 #define CK_USE_AMD_INLINE_ASM 1
+
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0

 namespace ck {

 // For some reason, HIP compiler need this definition to generate optimal load and store
...
@@ -9,6 +9,9 @@
 #include "helper_cuda.h"

 #define CK_USE_AMD_INLINE_ASM 0
+
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
+#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1 0

 namespace ck {

 // For some reason, CUDA need this definition, otherwise
...
@@ -24,10 +24,8 @@ struct swallow
 };

 // Emulate if constexpr
-template <bool Predicate>
-struct static_if
-{
-};
+template <bool>
+struct static_if;

 template <>
 struct static_if<true>
...
 #ifndef CK_INTEGRAL_CONSTANT_HPP
 #define CK_INTEGRAL_CONSTANT_HPP

+#include <type_traits>
+
 namespace ck {

-template <class T, T N>
-struct integral_constant
-{
-    static const T value = N;
-
-    __host__ __device__ constexpr T Get() const { return value; }
-};
+template <class T, T v>
+using integral_constant = std::integral_constant<T, v>;

 template <class T, T X, T Y>
 __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
@@ -17,6 +14,12 @@ __host__ __device__ constexpr auto operator+(integral_constant<T, X>, integral_c
     return integral_constant<T, X + Y>{};
 }

+template <class T, T X, T Y>
+__host__ __device__ constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
+{
+    return integral_constant<T, X * Y>{};
+}
+
 template <index_t N>
 using Number = integral_constant<index_t, N>;
...
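With integral_constant aliased to std::integral_constant and operator+/operator* defined on it, Number<> arithmetic stays inside the type system. A small standalone usage sketch (plain int instead of index_t, operators re-declared locally so the snippet compiles on its own):

    #include <type_traits>

    template <class T, T v>
    using integral_constant = std::integral_constant<T, v>;

    template <class T, T X, T Y>
    constexpr auto operator+(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X + Y>{};
    }

    template <class T, T X, T Y>
    constexpr auto operator*(integral_constant<T, X>, integral_constant<T, Y>)
    {
        return integral_constant<T, X * Y>{};
    }

    template <int N>
    using Number = integral_constant<int, N>;

    // the result of each operator is again a compile-time constant carried in the type
    constexpr auto n = Number<3>{} * Number<4>{} + Number<2>{};

    static_assert(n == 14, "value is usable as a constant");
    static_assert(std::is_same<decltype(n), const integral_constant<int, 14>>::value,
                  "the value lives in the type");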
 #ifndef CK_UTILITY_HPP
 #define CK_UTILITY_HPP

+#include <type_traits>
+
 #include "config.hpp"

 namespace ck {
@@ -9,23 +10,8 @@ __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

-template <class T1, class T2>
-struct is_same
-{
-    static constexpr bool value = false;
-};
-
-template <class T>
-struct is_same<T, T>
-{
-    static constexpr bool value = true;
-};
-
 template <class X, class Y>
-__host__ __device__ constexpr bool is_same_type(X, Y)
-{
-    return is_same<X, Y>::value;
-}
+using is_same = std::is_same<X, Y>;

 namespace math {
@@ -58,7 +44,7 @@ struct integer_divide_ceiler
 {
     __host__ __device__ constexpr T operator()(T a, T b) const
     {
-        static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+        static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");

         return (a + b - 1) / b;
     }
@@ -67,7 +53,7 @@ struct integer_divide_ceiler
 template <class T>
 __host__ __device__ constexpr T integer_divide_ceil(T a, T b)
 {
-    static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
+    static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");

     return (a + b - 1) / b;
 }
@@ -85,7 +71,7 @@ __host__ __device__ constexpr T max(T x, Ts... xs)
     auto y = max(xs...);

-    static_assert(is_same<decltype(y), T>::value, "not the same type");
+    static_assert(is_same<decltype(y), T>{}, "not the same type");

     return x > y ? x : y;
 }
@@ -103,12 +89,12 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     auto y = min(xs...);

-    static_assert(is_same<decltype(y), T>::value, "not the same type");
+    static_assert(is_same<decltype(y), T>{}, "not the same type");

     return x < y ? x : y;
 }

-// this is wrong
+// this is WRONG
 // TODO: implement least common multiple properly, instead of calling max()
 template <class T, class... Ts>
 __host__ __device__ constexpr T lcm(T x, Ts... xs)
...
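The lcm above is flagged as wrong because it currently falls back on max(). A standalone sketch of the definition the TODO asks for, built on gcd (a hypothetical replacement, not what this commit ships):

    template <class T>
    constexpr T gcd(T a, T b)
    {
        return b == 0 ? a : gcd(b, a % b);
    }

    // binary least common multiple: a * b / gcd(a, b), dividing first to limit overflow
    template <class T>
    constexpr T lcm2(T a, T b)
    {
        return (a / gcd(a, b)) * b;
    }

    template <class T>
    constexpr T lcm(T x)
    {
        return x;
    }

    // variadic form folds lcm2 over the arguments, matching how ck's max/min recurse
    template <class T, class U, class... Ts>
    constexpr T lcm(T x, U y, Ts... rest)
    {
        return lcm(lcm2(x, static_cast<T>(y)), rest...);
    }

    static_assert(lcm(4, 6) == 12, "");
    static_assert(lcm(2, 3, 8) == 24, "");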
@@ -64,131 +64,6 @@ struct vector_type<float, 4>
     }
 };

-#if 0
-template <>
-struct vector_type<half, 1>
-{
-    using MemoryType = half;
-
-    __host__ __device__ static MemoryType Pack(half s) { return s; }
-};
-
-template <>
-struct vector_type<half, 2>
-{
-    using MemoryType = half2;
-
-    __host__ __device__ static MemoryType Pack(half s0, half s1)
-    {
-        union
-        {
-            MemoryType vector;
-            half scalar[2];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<half, 4>
-{
-    using MemoryType = float2;
-};
-
-template <>
-struct vector_type<half, 8>
-{
-    using MemoryType = float4;
-};
-
-template <>
-struct vector_type<char, 1>
-{
-    using MemoryType = char;
-
-    __host__ __device__ static MemoryType Pack(char s) { return s; }
-};
-
-template <>
-struct vector_type<char, 2>
-{
-    using MemoryType = int16_t;
-
-    __host__ __device__ static MemoryType Pack(char s0, char s1)
-    {
-        union
-        {
-            MemoryType vector;
-            char scalar[2];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<char, 4>
-{
-    using MemoryType = int32_t;
-
-    __host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
-    {
-        union
-        {
-            MemoryType vector;
-            char scalar[4];
-        } data;
-
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        data.scalar[2] = s2;
-        data.scalar[3] = s3;
-
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<char, 8>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<int32_t, 2>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<char2, 2>
-{
-    using MemoryType = char4;
-};
-
-template <>
-struct vector_type<char2, 4>
-{
-    using MemoryType = int64_t;
-};
-
-template <>
-struct vector_type<char4, 1>
-{
-    using MemoryType = int;
-};
-
-template <>
-struct vector_type<char4, 2>
-{
-    using MemoryType = int64_t;
-};
-#endif
-
 } // namespace ck

 #endif
@@ -46,7 +46,7 @@ auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
 template <class F, class T>
 auto call_f_unpack_args(F f, T args)
 {
-    constexpr std::size_t N = std::tuple_size<T>::value;
+    constexpr std::size_t N = std::tuple_size<T>{};

     return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
 }
@@ -60,7 +60,7 @@ auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
 template <class F, class T>
 auto construct_f_unpack_args(F, T args)
 {
-    constexpr std::size_t N = std::tuple_size<T>::value;
+    constexpr std::size_t N = std::tuple_size<T>{};

     return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
 }
...
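The std::tuple_size<T>::value to std::tuple_size<T>{} change here, like the is_same<...>{} and is_valid_sequence_map<...>{} changes elsewhere in this commit, works because std::integral_constant has a constexpr conversion operator to its value_type, so a default-constructed trait object can stand in wherever a constant value is expected. Standalone illustration:

    #include <cstddef>
    #include <tuple>
    #include <type_traits>

    using T = std::tuple<int, float, double>;

    constexpr std::size_t n1 = std::tuple_size<T>::value; // traditional spelling
    constexpr std::size_t n2 = std::tuple_size<T>{};      // conversion operator, same constant

    static_assert(n1 == 3 && n2 == 3, "both spellings yield the same value");
    static_assert(std::is_same<int, int>{}, "traits are usable directly in static_assert");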