Commit bf111ac6 authored by Jing Zhang

preload

parent 9728f541
@@ -264,6 +264,72 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
__syncthreads();
#if 1
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{},
Number<1>{},
Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{}));
constexpr auto vec_len =
d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize() * CThreadTransferDstScalarPerVector;
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
static_assert(CThreadTransferDstScalarPerVector == 16, "");
const index_t k_block_data_on_global_add =
k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
const index_t k_thread_data_on_global_add =
k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
static_assert(vec_len == 256, "");
vector_type<int8_t, vec_len> d_vec;
{
ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC,
decltype(d_vec),
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks);
}
#endif
index_t b_block_data_begin = 0;
if constexpr(HasMainKBlockLoop)
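Read together with the deletion hunk that follows, this addition hoists the d-tensor global-to-VGPR transfer ahead of the main K-block loop (the commit's "preload"), so the global read is already in flight while the GEMM main loop executes. Below is a minimal standalone sketch of the block/thread-to-global index decomposition the transfer uses; the tile sizes are illustrative assumptions, not values from this commit, and only the shape of the formulas mirrors the diff.

// Sketch only: illustrative tile sizes, not values from this commit.
constexpr int KPerBlock = 16, KPerThread = 16;
constexpr int HoPerBlock = 8, HoPerThread = 4;
constexpr int WoPerBlock = 8, WoPerThread = 4;
constexpr int VecSize = 16; // stands in for CThreadTransferDstScalarPerVector

// K is pre-divided by the vector width: one transfer element is a VecSize-wide
// int8 vector, which is presumably why the transfer above passes 1 where
// CThreadTransferDstScalarPerVector is commented out. Ho/Wo advance at twice
// the per-block/per-thread granularity, matching the *x2 names in the diff.
constexpr int global_k(int kb, int kt) { return kb * KPerBlock / VecSize + kt * (KPerThread / VecSize); }
constexpr int global_ho(int hb, int ht) { return hb * HoPerBlock * 2 + ht * HoPerThread * 2; }
constexpr int global_wo(int wb, int wt) { return wb * WoPerBlock * 2 + wt * WoPerThread * 2; }

static_assert(global_k(0, 1) == 1, "thread 1 starts one K vector past thread 0");
static_assert(global_ho(1, 0) == 16, "block 1 starts one doubled Ho block over");
static_assert(global_wo(1, 1) == 24, "block offset plus doubled per-thread offset");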
@@ -452,66 +518,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
}
#else
{
constexpr auto HoPerThreadx2 = HoPerThread * 2;
constexpr auto WoPerThreadx2 = WoPerThread * 2;
const index_t hox2_block_data_on_global = ho_block_work_id * HoPerBlock * 2;
const index_t wox2_block_data_on_global = wo_block_work_id * WoPerBlock * 2;
const index_t hox2_thread_data_on_global =
hox2_block_data_on_global + ho_thread_id * HoPerThreadx2;
const index_t wox2_thread_data_on_global =
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
static_assert(CThreadTransferDstScalarPerVector == 16, "");
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_block_data_on_global_add =
k_block_work_id * KPerBlock / CThreadTransferDstScalarPerVector;
const index_t k_thread_data_on_global_add =
k_block_data_on_global_add + k_thread_id * KPerThreadAdd;
constexpr auto d_k_n_hox2_wox2_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number<KPerThreadAdd>{},
Number<1>{},
Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{}));
constexpr auto vec_len = d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize() *
CThreadTransferDstScalarPerVector;
static_assert(vec_len == 256, "");
vector_type<int8_t, vec_len> d_vec;
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC,
decltype(d_vec),
decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
// CThreadTransferDstScalarPerVector,
1,
AddressSpace::Global,
AddressSpace::Vgpr,
InMemoryDataOperation::Set,
1,
true>(d_k_n_hox2_wox2_global_desc,
make_multi_index(k_thread_data_on_global_add,
0,
hox2_thread_data_on_global,
wox2_thread_data_on_global))
.Run(d_k_n_hox2_wox2_global_desc,
p_d_global,
d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0),
d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks);
static_for<0, KPerThreadAdd, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
...
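The hunk above deletes the same transfer from its original position inside the #else branch, leaving only the static_for loops that consume d_vec; this is what makes the addition before the main loop a pure hoist rather than new logic. As a reference point for the two static_asserts, here is the vec_len arithmetic under one assumed assignment of the per-thread sizes (the concrete numbers are assumptions; only the asserted products come from the diff):

// Sketch of the vec_len arithmetic behind static_assert(vec_len == 256).
// The thread descriptor is packed, so GetElementSpaceSize() is the product of
// its four lengths; the values below are one assignment that satisfies both
// static_asserts in the diff, not values taken from this commit.
constexpr int KPerThread = 16;
constexpr int CThreadTransferDstScalarPerVector = 16; // asserted == 16 in the diff
constexpr int KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector; // 1
constexpr int HoPerThreadx2 = 2 * 2; // HoPerThread assumed 2
constexpr int WoPerThreadx2 = 2 * 2; // WoPerThread assumed 2

constexpr int element_space = KPerThreadAdd * 1 * HoPerThreadx2 * WoPerThreadx2; // 16 vectors
constexpr int vec_len = element_space * CThreadTransferDstScalarPerVector; // 256 int8 values
static_assert(vec_len == 256, "one 256-byte int8 preload per thread");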
@@ -28,10 +28,10 @@
#endif
// launch bounds
-#define CK_USE_LAUNCH_BOUNDS 0
+#define CK_USE_LAUNCH_BOUNDS 1
#ifdef CK_USE_LAUNCH_BOUNDS
-#define CK_MAX_THREAD_PER_BLOCK 256
+#define CK_MAX_THREAD_PER_BLOCK 64
#define CK_MIN_BLOCK_PER_CU 1
#endif
...
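This last hunk enables the launch-bounds switch and lowers the per-block thread cap from 256 to 64. Note that the guard which follows is #ifdef, which holds whether the flag is 0 or 1, so the 0-to-1 flip only takes effect where the flag is tested with #if. Below is a hedged sketch of how such macros are commonly consumed in HIP code; CK_KERNEL and example_kernel are hypothetical names, not from this commit.

#include <hip/hip_runtime.h>

// Redefined here so the sketch is self-contained; in the repo these come from
// the header changed above.
#define CK_USE_LAUNCH_BOUNDS 1
#define CK_MAX_THREAD_PER_BLOCK 64
#define CK_MIN_BLOCK_PER_CU 1

#if CK_USE_LAUNCH_BOUNDS
#define CK_KERNEL __global__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#else
#define CK_KERNEL __global__
#endif

// With the attribute, the compiler budgets registers for a 64-thread block and
// keeps at least CK_MIN_BLOCK_PER_CU block resident per compute unit.
CK_KERNEL void example_kernel() {}

Capping the block at 64 threads lets the compiler assign more registers per thread, which plausibly accommodates the 256-byte per-thread d_vec preload introduced above.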