Commit 91d13ef4 authored by Bartlomiej Wroblewski

Add support for fp16

parent 5f4c1ddb
@@ -59,6 +59,9 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
 if(GPU_TARGETS MATCHES "gfx90a")
     add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
     add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
+    add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
+    add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
 endif()
 add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
...

New file: gemm_xdl_lds_direct_load_fp16.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "common.hpp"
#define USING_DIRECT_LOADS 1
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif
using F16 = ck::half_t;
using F32 = float;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
#if USING_DIRECT_LOADS
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
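For orientation (not part of the commit), a minimal sketch of how DeviceGemmInstance is driven on the host; p_a/p_b/p_c, M/N/K and the strides are illustrative placeholders, and the actual logic lives in run_gemm_example.inc:

// Hypothetical host-side driver (a sketch of what run_gemm_example.inc does in essence;
// p_a/p_b/p_c are assumed to be device buffers of ADataType/BDataType/CDataType).
auto gemm     = DeviceGemmInstance{};
auto invoker  = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                  AElementOp{}, BElementOp{}, CElementOp{});
if(gemm.IsSupportedArgument(argument))
{
    // Run returns the average kernel time in ms when timing is enabled.
    const float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
}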
@@ -38,8 +38,13 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
     static constexpr auto block_slice_lengths    = BlockSliceLengths{};
     static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
 
-    static constexpr auto thread_slice_lengths = block_slice_lengths / thread_cluster_lengths;
-    static constexpr auto thread_steps         = thread_cluster_lengths;
+    static constexpr auto thread_single_load_size = generate_sequence(
+        detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
+    // After a load, each thread moves by `thread_steps` instead of loading the next elements.
+    // This makes the whole wavefront load contiguous memory, which is required for direct loads.
+    static constexpr auto thread_steps         = thread_cluster_lengths * thread_single_load_size;
+    static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
 
     __device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad(
         const SrcDesc& src_desc,
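For illustration (not part of the diff), here is roughly how these quantities work out for the A-block transfer of the fp16 example above, assuming its S<4, 16, 4> thread cluster, ScalarPerVector = 2, DstVectorDim = 2, and a K0 x M x K1 block slice of 4 x 128 x 8:

// thread_single_load_size = {1, 1, 2}   // 2 x half_t = 4 bytes = one DWORD per load
// thread_steps            = {4, 16, 8}  // thread_cluster_lengths * thread_single_load_size
// thread_slice_lengths    = {1, 8, 1}   // block_slice_lengths / thread_steps
// i.e. each of the 256 threads issues 8 direct loads, and every load writes exactly one DWORD to LDS.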
@@ -51,18 +56,29 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
         static_assert(NumLdsBuffers == 1,
                       "Direct load transfer does not support multiple LDS buffers.");
 
-        static_assert(ScalarPerVector == 1,
-                      "Direct load transfer does not support vectorized transfers.");
+        static_assert(ck::is_same_v<SrcData, DstData>,
+                      "Direct load transfer does not support data type conversion. Source and "
+                      "destination data types must be the same.");
+        static_assert(
+            DstVectorDim == nDim - 1,
+            "Direct load transfer requires the destination vector dimension to be the last one.");
+        static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
+                      "When loading more than one element per thread at once, the contiguous "
+                      "dimension must be the same between source and destination.");
+
+        constexpr auto dword_bytes           = 4;
+        constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
+        static_assert(bytes_per_thread_load == dword_bytes,
+                      "Direct load transfer requires each thread to load exactly a single "
+                      "DWORD of data.");
 
         static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                           nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                           nDim == ThreadClusterLengths::Size(),
                       "Inconsistent number of dimensions across lengths and descriptors.");
-        static_assert(
-            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
-            "Threads must be mapped to cover entire slicing window.");
         static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                       "The number of threads cannot be less than the number of elements in "
                       "thread cluster lengths.");
@@ -70,7 +86,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
         const auto thread_cluster_idx =
             thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
 
-        const auto thread_data_idx_begin = thread_cluster_idx;
+        const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
 
         SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
         SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
@@ -105,10 +121,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
                       "Destination data must be stored in an LDS memory buffer.");
 
         static_assert(
-            is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
+            ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
             "SrcBuffer and SrcData data types must be consistent.");
         static_assert(
-            is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
+            ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
             "DstBuffer and DstData data types must be consistent.");
 
         constexpr auto dst_access_lengths = thread_slice_lengths;
@@ -127,7 +143,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
             const bool is_src_valid =
                 coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
 
-            src_buf.CopyTo(dst_buf, src_offset, dst_offset, is_src_valid);
+            src_buf.template CopyTo<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
+                dst_buf, src_offset, dst_offset, is_src_valid);
 
             constexpr auto move_on_dim = [&]() constexpr
             {
...
@@ -947,14 +947,14 @@ amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thr
 // Direct loads from global to LDS.
 __device__ void
 llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
-                                __attribute__((address_space(3))) float* lds_ptr,
+                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                 index_t size,
                                 index_t voffset,
                                 index_t soffset,
                                 index_t offset,
                                 index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
 
-template <typename T>
+template <typename T, index_t NumElemsPerThread>
 __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                               const index_t global_offset,
                                               T* lds_base_ptr,
@@ -963,17 +963,22 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                               const index_t src_element_space_size)
 {
     // Direct loads require that each thread writes a single DWORD.
-    static_assert(sizeof(T) == 4);
+    constexpr auto dword_bytes      = 4;
+    constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
+    static_assert(bytes_per_thread == dword_bytes);
 
-    const int32x4_t src_resource =
-        make_wave_buffer_resource(global_base_ptr, src_element_space_size);
+    const uint32_t* global_ptr =
+        reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
+    const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
     const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
 
     // LDS pointer must be attributed with the LDS address space.
-    __attribute__((address_space(3))) T* lds_ptr =
-        reinterpret_cast<__attribute__((address_space(3))) T*>(
+    __attribute__((address_space(3))) uint32_t* lds_ptr =
+        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
             reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
 
-    llvm_amdgcn_raw_buffer_load_lds(src_resource, lds_ptr, sizeof(T), global_offset_bytes, 0, 0, 0);
+    llvm_amdgcn_raw_buffer_load_lds(
+        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
 }
 
 } // namespace ck
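To make the new NumElemsPerThread parameter concrete, call shapes that satisfy the DWORD assert look roughly like this (hypothetical, for illustration only; real calls are issued through DynamicBuffer::CopyTo below):

// amd_direct_load_global_to_lds<float, 1>(...);   // 1 element x 4 B = 1 DWORD per thread
// amd_direct_load_global_to_lds<half_t, 2>(...);  // 2 elements x 2 B = 1 DWORD per thread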
@@ -173,7 +173,7 @@ struct DynamicBuffer
         }
     }
 
-    template <typename DstBuffer>
+    template <typename DstBuffer, index_t NumElemsPerThread>
     __host__ __device__ void
     CopyTo(DstBuffer& dst_buf, index_t src_offset, index_t dst_offset, bool is_valid_element) const
     {
@@ -182,12 +182,12 @@ struct DynamicBuffer
         static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                       "Destination data must be stored in an LDS memory buffer.");
 
-        amd_direct_load_global_to_lds(p_data_,
+        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                       src_offset,
                                       dst_buf.p_data_,
                                       dst_offset,
                                       is_valid_element,
                                       element_space_size_);
     }
 
     template <typename X,
...
@@ -227,6 +227,10 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
         DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
         instances);
 
+void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<
+        DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+        instances);
 #endif
 #ifdef CK_ENABLE_BF16
 void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
@@ -456,6 +460,8 @@ struct DeviceOperationInstanceFactory<
 #endif
                 add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
                 add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
+                    op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                               is_same_v<CLayout, Row>)
...
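For orientation (not part of the commit), a minimal sketch of how a client could pick up the newly registered instances, assuming the usual CK instance-factory API (GetInstances, GetTypeString):

// The instance list for Row x Col -> Row fp16 GEMMs now also contains the two
// LDS-direct-load instances defined in the new file below.
using DeviceOp = ck::tensor_operation::device::DeviceGemm<Row, Col, Row, F16, F16, F16,
                                                          PassThrough, PassThrough, PassThrough>;
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
for(const auto& op : op_ptrs)
    std::cout << op->GetTypeString() << std::endl;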
@@ -42,6 +42,7 @@ list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
     device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
     device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
     device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
+    device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
     device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
     device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
     device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
...

New file: device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
// ##################################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances)
{
add_device_operation_instances(
instances, device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck