Commit 0e221501 authored by Jing Zhang's avatar Jing Zhang
Browse files

tweak

parent fb04c9be
...@@ -366,8 +366,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -366,8 +366,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Number<HoPerThread>{}, Number<HoPerThread>{},
Number<WoPerThread>{})); Number<WoPerThread>{}));
static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, "");
const index_t vec_len = c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() * const index_t vec_len = c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() *
CThreadTransferDstScalarPerVector; CThreadTransferDstScalarPerVector;
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
#define CK_USE_LAUNCH_BOUNDS 0 #define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS #ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256 #define CK_MAX_THREAD_PER_BLOCK 64
#define CK_MIN_BLOCK_PER_CU 1 #define CK_MIN_BLOCK_PER_CU 1
#endif #endif
......
...@@ -469,7 +469,71 @@ struct vector_type<T, 64> ...@@ -469,7 +469,71 @@ struct vector_type<T, 64>
} }
}; };
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d128_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
template <typename T> template <typename T>
struct vector_type<T, 256> struct vector_type<T, 256>
......
...@@ -126,8 +126,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( ...@@ -126,8 +126,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2; constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock; constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>; using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>; using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1; constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1; constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment