Commit 068fb458 authored by liuhy

Adapt the CK code to gfx926

parent acd8b8ea
@@ -9,7 +9,10 @@ cd client_example/build
 ```
 ```bash
-cmake -D CMAKE_CXX_COMPILER=${ROCM_PATH}/bin/hipcc -D CMAKE_PREFIX_PATH="${ROCM_PATH};${PATH_TO_CK_INSTALL_DIRECTORY}" ..
+cmake \
+-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
+-D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \
+..
 ```
 ### Build client example
...
@@ -27,7 +27,6 @@ message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLA
 FetchContent_Declare(
 googletest
 GIT_REPOSITORY http://10.0.50.24/Mirrors/googletest.git
-# GIT_REPOSITORY /work/home/zhangshao/installer/googletest
 GIT_TAG b85864c64758dec007208e56af933fc3f52044ee
 )
...
@@ -7,7 +7,7 @@ cmake
 -D CMAKE_CXX_FLAGS="-O3" \
 -D CMAKE_BUILD_TYPE=Release \
 -D GPU_TARGETS="gfx906;gfx926" \
--D CMAKE_INSTALL_PREFIX=~/composable_kernel/install_ck \
+-D CMAKE_INSTALL_PREFIX=~/composable_kernel-develop/install_ck \
 ..
 cd -
@@ -33,7 +33,7 @@
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx908__) || \
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || \
 defined(__gfx90a__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx1030__) // for GPU code
@@ -46,7 +46,7 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
+#elif defined(__gfx906__) || defined(__gfx926__) || defined(__gfx908__) || defined(__gfx90a__) || \
 defined(__gfx1030__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
...
@@ -225,13 +225,13 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
+auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
 b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());
 constexpr auto threadwise_contraction =
 ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
 FloatA,
-FloatB,
+FloatA,
 FloatC,
 decltype(a_thread_desc_bk0_bm0_bm1_bk1_),
 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
@@ -394,8 +394,8 @@ struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_
 Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder
 using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1<
-FloatB,
-FloatB,
+FloatB, // src
+FloatA, // dst
 decltype(b_block_desc_bk0_bn0_bn1_bk1_),
 decltype(b_thread_desc_bk0_bn0_bn1_bk1_),
 Sequence<BK0PerThread, 1, BN1PerThreadBN11, BK1>, // SliceLengths
...
@@ -134,7 +134,7 @@ __global__ void
 const Block2CTileMap block_2_ctile_map,
 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)||defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
 // offset base pointer for each work-group
 const index_t num_blocks_per_batch =
 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -709,7 +709,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
 namespace ctc = tensor_layout::convolution;
 // check device
-if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
 {
 return false;
 }
...
@@ -106,7 +106,7 @@ __global__ void
 const Block2CTileMap block_2_ctile_map,
 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__)|| defined(__gfx926__) || defined(__gfx1030__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx926__) || defined(__gfx1030__))
 // offset base pointer for each work-group
 const index_t num_blocks_per_batch =
 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -600,7 +600,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
 namespace ctc = tensor_layout::convolution;
 // check device
-if(!(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030"))
 {
 return false;
 }
...
@@ -1391,7 +1391,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
 static bool IsSupportedArgument(const Argument& arg)
 {
 // check device
-if(!(ck::get_device_name() == "gfx906"||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
+if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030"))
 {
 return false;
 }
...
@@ -205,6 +205,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 using GridwiseGemm =
 GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
 ADataType,
+BDataType,
 AccDataType,
 CDataType,
 InMemoryDataOperationEnum::Set,
@@ -364,6 +365,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 const auto kernel =
 kernel_gemm_dl_v1r3<GridwiseGemm,
 ADataType,
+BDataType,
 CDataType,
 remove_reference_t<AGridDesc_K0_M0_M1_K1>,
 remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -390,6 +392,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 const auto kernel =
 kernel_gemm_dl_v1r3<GridwiseGemm,
 ADataType,
+BDataType,
 CDataType,
 remove_reference_t<AGridDesc_K0_M0_M1_K1>,
 remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -416,6 +419,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 const auto kernel =
 kernel_gemm_dl_v1r3<GridwiseGemm,
 ADataType,
+BDataType,
 CDataType,
 remove_reference_t<AGridDesc_K0_M0_M1_K1>,
 remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -442,6 +446,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 const auto kernel =
 kernel_gemm_dl_v1r3<GridwiseGemm,
 ADataType,
+BDataType,
 CDataType,
 remove_reference_t<AGridDesc_K0_M0_M1_K1>,
 remove_reference_t<BGridDesc_K0_N0_N1_K1>,
@@ -483,7 +488,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
 static bool IsSupportedArgument(const Argument& arg)
 {
-if(ck::get_device_name() == "gfx906" ||ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")
+if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx926" || ck::get_device_name() == "gfx1030")
 {
 return GridwiseGemm::CheckValidity(
 arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
...
@@ -117,18 +117,11 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
-bool ret=(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
+return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
 K0 == b_grid_desc_k0_n_k1.GetLength(I0) &&
 K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
 K1 == b_grid_desc_k0_n_k1.GetLength(I2)) &&
 (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
-if(!ret){
-std::cout<<"M="<<M<<" N="<<N<<" K0="<<K0<<" c_grid_desc_m_n[0]="<<c_grid_desc_m_n.GetLength(I0)<<" c_grid_desc_m_n[1]="<<c_grid_desc_m_n.GetLength(I1)
-<<" b_grid_desc_k0_n_k1[0]="<<b_grid_desc_k0_n_k1.GetLength(I0)<<" b_grid_desc_k0_n_k1[2]="<<b_grid_desc_k0_n_k1.GetLength(I2)
-<<" a_grid_desc_k0_m_k1[2]="<<a_grid_desc_k0_m_k1.GetLength(I2)
-<<" K1="<<K1<<" MPerBlock="<<MPerBlock<<" NPerBlock="<<NPerBlock<<" K0PerBlock="<<K0PerBlock<<std::endl;
-}
-return ret;
 }
 __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
...
@@ -18,7 +18,8 @@
 namespace ck {
 template <typename GridwiseGemm,
-typename FloatAB,
+typename FloatA,
+typename FloatB,
 typename FloatC,
 typename AGridDesc_K0_M0_M1_K1,
 typename BGridDesc_K0_N0_N1_K1,
@@ -30,23 +31,27 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
+kernel_gemm_dl_v1r3(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
 const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
 const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
 const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
 const Block2CTileMap block_2_ctile_map)
 {
-constexpr index_t shared_block_size =
-GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+constexpr index_t shared_block_size_of_a =
+GridwiseGemm::GetSharedMemoryNumberOfByteToA() / sizeof(FloatA);
+constexpr index_t shared_block_size_of_b =
+GridwiseGemm::GetSharedMemoryNumberOfByteToB() / sizeof(FloatB);
-__shared__ FloatAB p_shared_block[shared_block_size];
+__shared__ FloatA p_shared_block_a[shared_block_size_of_a];
+__shared__ FloatB p_shared_block_b[shared_block_size_of_b];
 GridwiseGemm::Run(p_a_grid,
 p_b_grid,
 p_c_grid,
-p_shared_block,
+p_shared_block_a,
+p_shared_block_b,
 a_grid_desc_k0_m0_m1_k1,
 b_grid_desc_k0_n0_n1_k1,
 c_grid_desc_m0_m10_m11_n0_n10_n11,
@@ -56,7 +61,8 @@ __global__ void
 }
 template <index_t BlockSize,
-typename FloatAB,
+typename FloatA,
+typename FloatB,
 typename FloatAcc,
 typename FloatC,
 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
@@ -99,7 +105,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 // K1 should be Number<...>
 static constexpr auto K1 = Number<K1Value>{};
-__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByteToA()
 {
 // TODO: change this. I think it needs multi-dimensional alignment
 constexpr auto max_lds_align = K1;
@@ -122,7 +128,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 constexpr auto b_block_aligned_space_size =
 math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
-return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
+return 2 * a_block_aligned_space_size * sizeof(FloatA);
+}
+__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByteToB()
+{
+// TODO: change this. I think it needs multi-dimensional alignment
+constexpr auto max_lds_align = K1;
+// TODO: check alignment
+// A matrix in LDS memory, dst of blockwise copy
+constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
+make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+// TODO: check alignment
+// B matrix in LDS memory, dst of blockwise copy
+constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
+make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+// TODO: check alignment
+// LDS allocation for A and B: be careful of alignment
+constexpr auto a_block_aligned_space_size =
+math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
+constexpr auto b_block_aligned_space_size =
+math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
+return 2 * b_block_aligned_space_size * sizeof(FloatB);
 }
 __host__ __device__ static constexpr bool
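With A and B no longer sharing a single FloatAB element type, the LDS requirement is reported per operand, so the two byte counts have to be summed wherever total LDS use is reasoned about. A minimal host-side sketch of such a check, assuming the GridwiseGemm alias instantiated in DeviceGemmDl above and an assumed 64 KiB per-workgroup LDS budget (neither stated in the diff):

```cpp
// Hypothetical compile-time check (not part of the commit): the split A/B
// allocations must together still fit in LDS. 64 KiB is an assumed budget.
constexpr ck::index_t lds_bytes_a = GridwiseGemm::GetSharedMemoryNumberOfByteToA();
constexpr ck::index_t lds_bytes_b = GridwiseGemm::GetSharedMemoryNumberOfByteToB();
static_assert(lds_bytes_a + lds_bytes_b <= 64 * 1024,
              "double-buffered A and B tiles exceed the LDS budget");
```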
...@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 ...@@ -145,14 +177,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
{ {
const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); // M0 * N0
return grid_size; return grid_size;
} }
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
{ {
const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; // K0 > K0PerBlock ???
return has_main_k_block_loop; return has_main_k_block_loop;
} }
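The integer arithmetic in the predicate above is easy to misread; a worked example with K0PerBlock = 16 (the value used by the DL instances, per the comments added in this commit) shows when the main double-buffered loop actually runs. For the K0 values CheckValidity accepts (multiples of K0PerBlock), the predicate is true exactly when K0 > 2 * K0PerBlock:

```cpp
// Illustrative only: more than one double-buffered iteration is needed
// only when K0 exceeds two K0PerBlock tiles.
constexpr int K0PerBlock = 16;
static_assert((32 + K0PerBlock) / (2 * K0PerBlock) == 1, "K0 = 32: tail only, no main loop");
static_assert((48 + K0PerBlock) / (2 * K0PerBlock) == 2, "K0 = 48: main loop runs");
```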
@@ -170,7 +202,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
-const auto M1 = Number<MPerBlock>{};
+const auto M1 = Number<MPerBlock>{}; // 128
 const auto M0 = M / M1;
 const auto a_grid_desc_k0_m0_m1_k1 =
@@ -178,8 +210,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 make_tuple(make_pass_through_transform(K0),
 make_unmerge_transform(make_tuple(M0, M1)),
 make_pass_through_transform(K1)),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), // K0, M, K1
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // K0, M0, M1, K1
 return a_grid_desc_k0_m0_m1_k1;
 }
@@ -190,7 +222,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
-const auto N1 = Number<NPerBlock>{};
+const auto N1 = Number<NPerBlock>{}; // 128
 const auto N0 = N / N1;
 const auto b_grid_desc_k0_n0_n1_k1 =
@@ -198,8 +230,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 make_tuple(make_pass_through_transform(K0),
 make_unmerge_transform(make_tuple(N0, N1)),
 make_pass_through_transform(K1)),
-make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), // K0, N, K1
+make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); // K0, N0, N1, K1
 return b_grid_desc_k0_n0_n1_k1;
 }
@@ -210,33 +242,33 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 const auto M = c_grid_desc_m_n.GetLength(I0);
 const auto N = c_grid_desc_m_n.GetLength(I1);
-constexpr auto M1 = Number<MPerBlock>{};
-constexpr auto N1 = Number<NPerBlock>{};
+constexpr auto M1 = Number<MPerBlock>{}; // 128
+constexpr auto N1 = Number<NPerBlock>{}; // 128
 const auto M0 = M / M1;
 const auto N0 = N / N1;
-constexpr auto M11 =
-Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) *
-M1PerThreadM111>{};
-constexpr auto N11 =
-Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) *
-N1PerThreadN111>{};
-constexpr auto M10 = M1 / M11;
-constexpr auto N10 = N1 / N11;
+constexpr auto M11 = // 64
+Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies{}, I1) * // S<8, 2> ==> 8*2=16
+M1PerThreadM111>{}; // M1PerThread 4
+constexpr auto N11 = // 64
+Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies{}, I1) * // 16
+N1PerThreadN111>{}; // N1PerThread 4
+constexpr auto M10 = M1 / M11; // 2
+constexpr auto N10 = N1 / N11; // 2
 const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor(
 c_grid_desc_m_n,
 make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
 make_unmerge_transform(make_tuple(N0, N10, N11))),
-make_tuple(Sequence<0>{}, Sequence<1>{}),
-make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
+make_tuple(Sequence<0>{}, Sequence<1>{}), // M, N
+make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); // M0, M10, M11, N0, N10, N11
 return c_grid_desc_m0_m10_m11_n0_n10_n11;
 }
-// return block_id to C matrix tile idx (m0, n0) mapping
+// return block_id to C matrix tile idx (m0, n0) mapping // TODO: clarify exactly what this mapping produces
 __host__ __device__ static constexpr auto
 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
 {
@@ -252,10 +284,11 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
 __device__ static void
-Run(const FloatAB* __restrict__ p_a_grid,
-const FloatAB* __restrict__ p_b_grid,
+Run(const FloatA* __restrict__ p_a_grid,
+const FloatB* __restrict__ p_b_grid,
 FloatC* __restrict__ p_c_grid,
-FloatAB* __restrict__ p_shared_block,
+FloatA* __restrict__ p_shared_block_a,
+FloatB* __restrict__ p_shared_block_b,
 const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
 const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
 const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
@@ -304,12 +337,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 // TODO: check alignment
 // A matrix in LDS memory, for blockwise GEMM
 constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
-make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
+make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2
 // TODO: check alignment
 // B matrix in LDS memory, for blockwise GEMM
 constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align); // (16, 128, 2), 2
 static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
 a_k0_m_k1_block_desc.GetElementSpaceSize() &&
@@ -325,8 +358,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
 ABlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatAB,
+FloatA,
+FloatA,
 remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
 decltype(a_block_desc_k0_m0_m1_k1),
 ABlockTransferSrcAccessOrder,
@@ -349,8 +382,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
 BBlockTransferThreadClusterArrangeOrder,
-FloatAB,
-FloatAB,
+FloatB,
+FloatB,
 remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
 decltype(b_block_desc_k0_n0_n1_k1),
 BBlockTransferSrcAccessOrder,
@@ -368,14 +401,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 // GEMM definition
 // c_mtx += transpose(a_mtx) * b_mtx
 // a_mtx[K0PerBlock, MPerBlock] is in LDS
-// b_mtx[KPerBlocl, NPerBlock] is in LDS
+// b_mtx[K0PerBlock, NPerBlock] is in LDS
 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
 // register
 const auto blockwise_gemm =
 BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
 BlockSize,
-FloatAB,
-FloatAB,
+FloatA, // TODO: split a/b
+FloatB,
 FloatAcc,
 decltype(a_k0_m_k1_block_desc),
 decltype(b_k0_n_k1_block_desc),
@@ -400,8 +433,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
 b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
-FloatAB* p_a_block_double = p_shared_block;
-FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+FloatA* p_a_block_double = p_shared_block_a;
+FloatB* p_b_block_double = p_shared_block_b;
 // register allocation for output
 auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
@@ -436,7 +469,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 if constexpr(HasMainKBlockLoop)
 {
-const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);
+const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); // K / K1(=2)
 index_t k_block_data_begin = 0;
@@ -487,7 +520,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
 b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
 k_block_data_begin += 2 * K0PerBlock;
-} while(k_block_data_begin < K0 - 2 * K0PerBlock);
+} while(k_block_data_begin < K0 - 2 * K0PerBlock); // K0PerBlock = 16
 }
 // LDS double buffer: tail
...
@@ -151,7 +151,7 @@ struct ThreadwiseTensorSliceTransfer_v4r1
 dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);
 dst_buf(Number<dst_offset>{}) = type_convert<DstData>(
-src_vector.template AsType<DstData>()[Number<src_vector_offset>{}]);
+src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}]);
 });
 });
 }
...
@@ -942,6 +942,10 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
 using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;
+// u8
+using uint8x2_t = typename vector_type<uint8_t, 2>::type;
+using uint8x4_t = typename vector_type<uint8_t, 4>::type;
 // Convert X to Y
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert(X x)
@@ -951,6 +955,13 @@ __host__ __device__ constexpr Y type_convert(X x)
 return static_cast<Y>(x);
 }
+// Convert uint8_t to half_t
+template <>
+__host__ __device__ constexpr half_t type_convert<half_t, uint8_t>(uint8_t x)
+{
+return static_cast<half_t>(x);
+}
 // convert bfp16 to fp32
 template <>
 inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
...
@@ -9,6 +9,31 @@ namespace ck {
 template <typename TA, typename TB, typename TC>
 __device__ void inner_product(const TA& a, const TB& b, TC& c);
+template <>
+__device__ void inner_product<half2_t, uint8x2_t, float>(const half2_t& a, const uint8x2_t& b, float& c)
+{
+const vector_type<half_t, 2> a_vector{a};
+const vector_type<uint8_t, 2> b_vector{b};
+const vector_type<half_t, 2> b_fp16_vector;
+static constexpr uint32_t mask_for_elt_01 = 0x05020500;
+static constexpr uint32_t mask_for_elt_23 = 0x05030501;
+static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
+asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[0]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_01));
+// asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(start_byte_for_fp16), "v"(((uint32_t*)&b_vector.data_)[0]), "v"(mask_for_elt_23));
+static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x6480;
+asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[0]) : "v"(((uint32_t*)&b_fp16_vector.data_)[0]), "v"(I8s_TO_F16s_MAGIC_NUM));
+// asm volatile("v_sub_f16x2 %0, %1, %2;\n" : "=v"(((uint32_t*)&b_fp16_vector.data_)[1]) : "v"(((uint32_t*)&b_fp16_vector.data_)[1]), "v"(I8s_TO_F16s_MAGIC_NUM));
+static_for<0, 2, 1>{}([&](auto i) {
+c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
+type_convert<int32_t>(b_fp16_vector.AsType<half_t>()[i]);
+});
+}
 template <>
 __device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
 {
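The new specialization converts the two uint8 values of b to fp16 in-register: the v_perm_b32/v_sub_f16x2 sequence packs each byte under a 0x64 high byte (an fp16 bit pattern of the form 1024 + x) and then subtracts a bias constant before the products are accumulated. A scalar sketch of the intended arithmetic, with illustrative names only and _Float16 standing in for ck::half_t (the committed loop additionally routes the factors through type_convert<int32_t>, mirroring the existing int8 path):

```cpp
// Scalar reference for the fp16 x u8 dot step above; not CK API.
#include <cstdint>

inline void inner_product_f16_u8_ref(const _Float16 a[2], const uint8_t b[2], float& c)
{
    for (int i = 0; i < 2; ++i)
    {
        // The device code aims at the same value by building an fp16 with a
        // 0x64 high byte around b[i] and subtracting the bias; here we just
        // convert the byte directly.
        const _Float16 b_fp16 = static_cast<_Float16>(b[i]);
        c += static_cast<float>(a[i]) * static_cast<float>(b_fp16);
    }
}
```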
@@ -71,12 +96,6 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
 c);
 }
-template <>
-__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
-{
-c+=static_cast<float>(a*b);
-}
 template <>
 __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
 {
...
@@ -15,6 +15,7 @@ namespace device {
 namespace instance {
 using F16 = ck::half_t;
+using I8 = int8_t;
 using F32 = float;
 using Row = ck::tensor_layout::gemm::RowMajor;
@@ -34,13 +35,13 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple<
 // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
 // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | |
 // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
+DeviceGemmDl< F16, I8, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
 // clang-format on
 >;
 void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
 std::vector<std::unique_ptr<
-DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
+DeviceGemm<Col, Row, Row, F16, I8, F16, PassThrough, PassThrough, PassThrough>>>&
 instances)
 {
 add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{});
...