Commit 0e221501 authored by Jing Zhang's avatar Jing Zhang
Browse files

tweak

parent fb04c9be
......@@ -366,8 +366,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
Number<HoPerThread>{},
Number<WoPerThread>{}));
static_assert(c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() == 4, "");
const index_t vec_len = c_k_n_ho_wo_thread_desc_vec.GetElementSpaceSize() *
CThreadTransferDstScalarPerVector;
......
......@@ -31,7 +31,7 @@
#define CK_USE_LAUNCH_BOUNDS 0
#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MAX_THREAD_PER_BLOCK 64
#define CK_MIN_BLOCK_PER_CU 1
#endif
......
......@@ -469,7 +469,71 @@ struct vector_type<T, 64>
}
};
template <typename T>
struct vector_type<T, 128>
{
using d1_t = T;
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d16_t __attribute__((ext_vector_type(16)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d128_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value ||
is_same<X, d16_t>::value ||
is_same<X, d128_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
}
};
template <typename T>
struct vector_type<T, 256>
......
......@@ -126,8 +126,8 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
constexpr index_t WoPerThread = 2;
constexpr index_t EPerThread = EPerBlock;
using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<EPerBlock, KPerBlock>;
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment