"tests/test_layers_utils.py" did not exist on "48269070d23ad8a4c6f31bc6847c358aac182ad1"
Commit ad7d9460 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean code

parent 0e221501
...@@ -156,12 +156,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -156,12 +156,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{}; ABlockTransferDstScalarPerVector_K>{};
// register allocation for output
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const auto k_thread_id = c_thread_mtx_index.k; const auto k_thread_id = c_thread_mtx_index.k;
...@@ -229,6 +223,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -229,6 +223,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
FloatAB* p_a_block = p_shared_block; FloatAB* p_a_block = p_shared_block;
// register allocation for output
FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
// hack to control index calculation when iterating over A and B matrix for threadwise copy // hack to control index calculation when iterating over A and B matrix for threadwise copy
...@@ -351,9 +351,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2 ...@@ -351,9 +351,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v2
// output: register to global memory // output: register to global memory
{ {
// hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
static_assert(CThreadTransferDstScalarPerVector == 16 && KPerBlock == 16, "");
const index_t k_block_data_on_global_vec = const index_t k_block_data_on_global_vec =
k_block_work_id * (KPerBlock / CThreadTransferDstScalarPerVector); k_block_work_id * (KPerBlock / CThreadTransferDstScalarPerVector);
const index_t KPerThreadVec = KPerThread / CThreadTransferDstScalarPerVector; const index_t KPerThreadVec = KPerThread / CThreadTransferDstScalarPerVector;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment