Commit 5469c613 authored by Bartlomiej Wroblewski
Browse files

[WIP] Add support for double buffering in direct load GEMM kernel

parent 627054b9
...@@ -134,6 +134,9 @@ ...@@ -134,6 +134,9 @@
// inner product using V_DOT with DPP8 modifiers // inner product using V_DOT with DPP8 modifiers
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
...@@ -380,7 +380,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout, ...@@ -380,7 +380,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
<< " LoopScheduler: " << " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", " << LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: " << "PipelineVersion: "
<< PipelineVersionToString[PipelineVer]; << PipelineVersionToString[PipelineVer] << ", "
<< "Prefetch: "
<< NumGemmKPrefetchStage;
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -236,9 +236,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -236,9 +236,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto c_block_size = constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + return math::max(
b_block_space_size_aligned * sizeof(BComputeDataType), NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) +
c_block_size * sizeof(CShuffleDataType)); NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
} }
__host__ __device__ static auto __host__ __device__ static auto
...@@ -491,6 +492,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -491,6 +492,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
// Creates NumGemmKPrefetchStage LDS dynamic buffers of `num_elems` elements each and returns
// them as a tuple.
//
// p_shared      - base pointer of the shared (LDS) allocation; it is reinterpreted as DataType*.
// num_elems     - number of DataType elements exposed by every buffer.
// offset_elems  - offset (counted in DataType elements) of the first buffer from p_shared.
// max_lds_align - alignment (in elements) used to round up the distance between two
//                 consecutive buffers.
//
// Buffer `s` starts at p_shared + offset_elems + s * round_up(num_elems, max_lds_align).
template <typename DataType>
__device__ static auto AllocateBlockBuffers(void* p_shared,
                                            int32_t num_elems,
                                            int32_t offset_elems,
                                            int32_t max_lds_align)
{
    // Distance between the starts of two consecutive stage buffers, in elements.
    const int32_t stage_stride = math::integer_least_multiple(num_elems, max_lds_align);

    // Start of the first stage buffer.
    DataType* const first_stage_ptr = static_cast<DataType*>(p_shared) + offset_elems;

    return generate_tuple(
        [&](auto stage) {
            const int32_t stage_offset = stage * stage_stride;
            return make_dynamic_buffer<AddressSpaceEnum::Lds>(first_stage_ptr + stage_offset,
                                                              num_elems);
        },
        Number<NumGemmKPrefetchStage>{});
}
template <bool HasMainKBlockLoop, template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1, typename BGridDesc_BK0_N_BK1,
...@@ -624,12 +641,14 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -624,12 +641,14 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr auto a_block_space_size_aligned = math::integer_least_multiple( constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buffers = AllocateBlockBuffers<AComputeDataType>(
static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buffers =
static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned, AllocateBlockBuffers<BComputeDataType>(p_shared,
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
b_buffers_offset,
max_lds_align);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
...@@ -645,13 +664,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -645,13 +664,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
a_blockwise_copy, a_blockwise_copy,
a_grid_buf, a_grid_buf,
a_block_buf, a_block_buffers,
a_block_slice_copy_step, a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
b_blockwise_copy, b_blockwise_copy,
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buffers,
b_block_slice_copy_step, b_block_slice_copy_step,
blockwise_gemm, blockwise_gemm,
c_thread_buf, c_thread_buf,
......
...@@ -17,7 +17,6 @@ template <> ...@@ -17,7 +17,6 @@ template <>
struct GridwiseGemmPipeline_v4<1> struct GridwiseGemmPipeline_v4<1>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
...@@ -31,13 +30,13 @@ struct GridwiseGemmPipeline_v4<1> ...@@ -31,13 +30,13 @@ struct GridwiseGemmPipeline_v4<1>
typename ABlockDesc, typename ABlockDesc,
typename ABlockTransfer, typename ABlockTransfer,
typename AGridBuffer, typename AGridBuffer,
typename ABlockBuffer, typename ABlockBuffers,
typename ABlockTransferStep, typename ABlockTransferStep,
typename BGridDesc, typename BGridDesc,
typename BBlockDesc, typename BBlockDesc,
typename BBlockTransfer, typename BBlockTransfer,
typename BGridBuffer, typename BGridBuffer,
typename BBlockBuffer, typename BBlockBuffers,
typename BBlockTransferStep, typename BBlockTransferStep,
typename BlockwiseGemm, typename BlockwiseGemm,
typename CThreadBuffer> typename CThreadBuffer>
...@@ -45,18 +44,22 @@ struct GridwiseGemmPipeline_v4<1> ...@@ -45,18 +44,22 @@ struct GridwiseGemmPipeline_v4<1>
const ABlockDesc& a_block_desc, const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy, ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf, const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf, ABlockBuffers& a_block_bufs,
const ABlockTransferStep& a_block_copy_step, const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc, const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc, const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy, BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf, BBlockBuffers& b_block_bufs,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm, const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) index_t num_loop)
{ {
static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1);
auto& a_block_buf = a_block_bufs.At(I0);
auto& b_block_buf = b_block_bufs.At(I0);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
...@@ -98,4 +101,116 @@ struct GridwiseGemmPipeline_v4<1> ...@@ -98,4 +101,116 @@ struct GridwiseGemmPipeline_v4<1>
} }
}; };
// GEMM pipeline with two prefetch stages (double buffering).
//
// Two block buffers per operand are used in a ping/pong fashion: while the blockwise GEMM
// consumes one pair of A/B block buffers, the copy of the next K tile into the other pair has
// already been issued. K tiles are processed two at a time, hence the loop count has to be even.
template <>
struct GridwiseGemmPipeline_v4<2>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    // Only an even number of K loops can be handled, because tiles are consumed in pairs.
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        return (num_loop % 2) == 0;
    }

    // The tail of Run() always processes one pair of tiles; the main loop is required only
    // when there is more than one pair.
    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return (num_loop / 2) > 1;
    }

    // Runs the whole K loop and accumulates the result into `c_thread_buf`.
    //
    // a_block_bufs / b_block_bufs must each hold exactly two buffers (checked at compile time).
    // num_loop is the number of K tiles; it is expected to satisfy IsSupported(), and
    // HasMainLoop is expected to equal CalculateHasMainLoop(num_loop) - NOTE(review): neither
    // is verified inside this function.
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffers,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffers,
              typename BBlockTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffers& a_block_bufs,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BBlockDesc& b_block_desc,
                               BBlockTransfer& b_blockwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BBlockBuffers& b_block_bufs,
                               const BBlockTransferStep& b_block_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2);

        // "ping" is the first buffer of each operand, "pong" the second one.
        auto& a_ping_buf = a_block_bufs.At(I0);
        auto& a_pong_buf = a_block_bufs.At(I1);
        auto& b_ping_buf = b_block_bufs.At(I0);
        auto& b_pong_buf = b_block_bufs.At(I1);

        // Issues the copy of the current A/B tiles into the given block buffers and then moves
        // both source windows forward by one step.
        const auto copy_tile_and_advance = [&](auto& a_dst_buf, auto& b_dst_buf) {
            a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_dst_buf);
            b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_dst_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        };

        // Prologue: fetch the first tile into the ping buffers.
        copy_tile_and_advance(a_ping_buf, b_ping_buf);

        // Initialize C
        c_thread_buf.Clear();

        // Main body: every iteration consumes two tiles (ping, then pong).
        if constexpr(HasMainLoop)
        {
            index_t tiles_done = 0;

            do
            {
                // Synchronize the block, then fill pong while computing on ping.
                block_sync_lds_direct_load();
                copy_tile_and_advance(a_pong_buf, b_pong_buf);
                blockwise_gemm.Run(a_ping_buf, b_ping_buf, c_thread_buf);

                // Synchronize the block, then refill ping while computing on pong.
                block_sync_lds_direct_load();
                copy_tile_and_advance(a_ping_buf, b_ping_buf);
                blockwise_gemm.Run(a_pong_buf, b_pong_buf, c_thread_buf);

                tiles_done += 2;
            } while(tiles_done < (num_loop - 2));
        }

        // Tail: the last pair of tiles. Only one more copy (into pong) is issued; the final
        // GEMM on pong is not overlapped with any further copy.
        {
            block_sync_lds_direct_load();
            copy_tile_and_advance(a_pong_buf, b_pong_buf);
            blockwise_gemm.Run(a_ping_buf, b_ping_buf, c_thread_buf);

            block_sync_lds_direct_load();
            blockwise_gemm.Run(a_pong_buf, b_pong_buf, c_thread_buf);
        }
    }
};
} // namespace ck } // namespace ck
...@@ -972,6 +972,16 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -972,6 +972,16 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
asm volatile("s_mov_b32 m0, %0; \n\t"
"s_nop 0;\n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
"s"(src_resource));
#else
// LDS pointer must be attributed with the LDS address space. // LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr = __attribute__((address_space(3))) uint32_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
...@@ -979,6 +989,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -979,6 +989,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
llvm_amdgcn_raw_buffer_load_lds( llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
} }
} // namespace ck } // namespace ck
...@@ -26,6 +26,12 @@ __device__ void block_sync_lds_direct_load() ...@@ -26,6 +26,12 @@ __device__ void block_sync_lds_direct_load()
s_waitcnt lgkmcnt(0) \n \ s_waitcnt lgkmcnt(0) \n \
s_barrier \ s_barrier \
" ::); " ::);
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
// When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
// `sched_barrier` is necessary to make sure that no instructions that use the loaded memory
// are scheduled by the compiler before the `waitcnt` instruction.
__builtin_amdgcn_sched_barrier(0);
#endif
} }
__device__ void s_nop() __device__ void s_nop()
......
...@@ -35,7 +35,26 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = ...@@ -35,7 +35,26 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4> DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 32, 128, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<1, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on // clang-format on
>; >;
......
...@@ -34,6 +34,14 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances = ...@@ -34,6 +34,14 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances =
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>, DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 16, 4, 4, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 64, 8, 8, 32, 32, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4> DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on // clang-format on
>; >;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment