Commit c1129f38 authored by Adam Osewski's avatar Adam Osewski
Browse files

Use KBatch to calculate K0

parent b66c4cc7
...@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -265,7 +265,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const index_t k_padded = GridwiseGemm::CalculateKPadded(K); const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH); const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH);
const auto c_grid_desc_m_n = const auto c_grid_desc_m_n =
...@@ -319,7 +319,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -319,7 +319,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto& karg = gemm_kernel_args_[i].karg_; auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K); const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH); const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment