Commit a476b4ba authored by Jing Zhang's avatar Jing Zhang
Browse files

debugging with array

parent 0924d5e5
...@@ -466,7 +466,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -466,7 +466,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
wox2_block_data_on_global + wo_thread_id * WoPerThreadx2; wox2_block_data_on_global + wo_thread_id * WoPerThreadx2;
static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, ""); static_assert(KPerThread % CThreadTransferDstScalarPerVector == 0, "");
// static_assert(CThreadTransferDstScalarPerVector == 16, ""); static_assert(CThreadTransferDstScalarPerVector == 16, "");
constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector; constexpr auto KPerThreadAdd = KPerThread / CThreadTransferDstScalarPerVector;
const index_t k_block_data_on_global_add = const index_t k_block_data_on_global_add =
...@@ -480,16 +480,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -480,16 +480,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
Number<HoPerThreadx2>{}, Number<HoPerThreadx2>{},
Number<WoPerThreadx2>{})); Number<WoPerThreadx2>{}));
constexpr auto vector_len = d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize() * FloatC d_vec[d_k_n_hox2_wox2_thread_desc.GetElementSpaceSize()];
CThreadTransferDstScalarPerVector;
vector_type<int8_t, vector_len> d_vec;
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
ThreadwiseDynamicTensorSliceTransfer_v2< ThreadwiseDynamicTensorSliceTransfer_v2<
FloatC, FloatC,
decltype(d_vec), FloatC,
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>, Sequence<KPerThreadAdd, 1, HoPerThreadx2, WoPerThreadx2>,
...@@ -506,30 +503,42 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -506,30 +503,42 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
0, 0,
hox2_thread_data_on_global, hox2_thread_data_on_global,
wox2_thread_data_on_global)) wox2_thread_data_on_global))
.Run2(d_k_n_hox2_wox2_global_desc, .Run(d_k_n_hox2_wox2_global_desc,
p_d_global, p_d_global,
d_k_n_hox2_wox2_thread_desc, d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
d_vec, d_vec,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
static_assert(vector_len == 256, ""); for(index_t k_i = 0; k_i < KPerThreadAdd; ++k_i)
{
static_for<0, vector_len, 1>{}([&](auto i) { for(index_t h_i = 0; h_i < HoPerThreadx2; ++h_i)
constexpr auto kpack_i = i % (CThreadTransferDstScalarPerVector); {
constexpr auto khw_i = i / (CThreadTransferDstScalarPerVector); for(index_t w_i = 0; w_i < WoPerThreadx2; ++w_i)
constexpr auto k_i = khw_i / (HoPerThreadx2 * WoPerThreadx2); {
constexpr auto hw_i = khw_i % (HoPerThreadx2 * WoPerThreadx2); vector_type<int8_t, CThreadTransferDstScalarPerVector> t;
constexpr auto h_i = hw_i / WoPerThreadx2;
constexpr auto w_i = hw_i % WoPerThreadx2; t.template AsType<FloatC>()(Number<0>{}) =
d_vec[d_k_n_hox2_wox2_thread_desc.CalculateOffset(
d_vec.template AsType<int8_t>()(i) = make_tuple(k_i, 0, h_i, w_i))];
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(make_tuple(
k_i * CThreadTransferDstScalarPerVector + kpack_i, 0, h_i / 2, w_i / 2))]; static_for<0, CThreadTransferDstScalarPerVector, 1>{}([&](auto i) {
}); t.template AsType<int8_t>()(i) +=
p_c_thread[c_k_n_ho_wo_thread_desc.CalculateOffset(
make_tuple(k_i * CThreadTransferDstScalarPerVector + i,
0,
h_i / 2,
w_i / 2))];
});
d_vec[d_k_n_hox2_wox2_thread_desc.CalculateOffset(make_tuple(
k_i, 0, h_i, w_i))] = t.template AsType<FloatC>()[Number<0>{}];
}
}
}
ThreadwiseDynamicTensorSliceTransfer_v1r3< ThreadwiseDynamicTensorSliceTransfer_v1r3<
decltype(d_vec), FloatC,
FloatC, FloatC,
decltype(d_k_n_hox2_wox2_thread_desc), decltype(d_k_n_hox2_wox2_thread_desc),
decltype(d_k_n_hox2_wox2_global_desc), decltype(d_k_n_hox2_wox2_global_desc),
...@@ -547,12 +556,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3 ...@@ -547,12 +556,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
0, 0,
hox2_thread_data_on_global, hox2_thread_data_on_global,
wox2_thread_data_on_global)) wox2_thread_data_on_global))
.Run2(d_k_n_hox2_wox2_thread_desc, .Run(d_k_n_hox2_wox2_thread_desc,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
d_vec, d_vec,
d_k_n_hox2_wox2_global_desc, d_k_n_hox2_wox2_global_desc,
p_c_global, p_c_global,
c_k_n_ho_wo_global_tensor_iterator_hacks); c_k_n_ho_wo_global_tensor_iterator_hacks);
} }
#endif #endif
} }
......
...@@ -377,8 +377,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 ...@@ -377,8 +377,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx +
i * dst_scalar_step_in_vector); i * dst_scalar_step_in_vector);
dst_vector.template AsType<DstData>()(i) = dst_vector.template AsType<DstData>()(i) = p_src.template AsType<DstData>()[i];
type_convert<DstData>{}(p_src.template AsType<DstData>()[i]);
}); });
const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment