Commit bced61c2 authored by Anthony Chang's avatar Anthony Chang
Browse files

amend

parent 9176cd6b
#pragma once
#include "common_header.hpp"
#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
namespace ck {
......@@ -341,8 +342,13 @@ struct GridwiseGemmPipelineInterwave_v1<1>
}
};
template <index_t NumPrefetch,
bool HasMainLoop>
// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
template <>
struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
{
};
template <index_t NumPrefetch, LoopScheduler LoopSched>
constexpr auto GridwiseGemmPipeline_v1_Selector()
{
if constexpr(LoopSched == LoopScheduler::Default)
......
......@@ -511,21 +511,21 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
......
......@@ -454,21 +454,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// shuffle C and write out
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment