Commit cf9bd973 authored by Chao Liu

refactoring blockwise gemm

parent 7d09790a
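
The diff below moves the blockwise GEMM off the old ConstantMatrixDescriptor interface (NRow()/NCol(), two-index offset calls) onto the generic tensor-descriptor interface (GetLengths(), CalculateOffset on a single multi-index). As a rough orientation, here is a minimal host-side sketch of the packed row-major offset convention the new calls correspond to; RowMajor2D and the sample sizes are illustrative stand-ins, not library code.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Mimics a packed 2D descriptor: lengths {L0, L1}, strides {L1, 1}.
struct RowMajor2D
{
    std::int64_t L0, L1;

    // Counterpart of CalculateOffset({i0, i1}) on a packed descriptor.
    std::int64_t CalculateOffset(std::array<std::int64_t, 2> idx) const
    {
        return idx[0] * L1 + idx[1];
    }
};

int main()
{
    RowMajor2D a_k_m{8, 128}; // e.g. a K x M block of A (A is stored transposed)
    // The old style was CalculateOffset(k, m); the refactor passes one multi-index.
    assert(a_k_m.CalculateOffset({2, 5}) == 2 * 128 + 5);
    return 0;
}
```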
@@ -2,7 +2,9 @@
 #define CK_BLOCKWISE_GEMM_HPP

 #include "common_header.hpp"
-#include "ConstantMatrixDescriptor.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_gemm.hpp"

 namespace ck {
@@ -38,16 +40,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
     {
+        static_assert(BlockMatrixA::GetNumOfDimension() == 2 &&
+                          BlockMatrixB::GetNumOfDimension() == 2 &&
+                          ThreadMatrixC::GetNumOfDimension() == 2,
+                      "wrong! A, B, C matrix should be 2D tensors");
+
         constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
                                                    MLevel1ThreadCluster * NLevel1ThreadCluster;

         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

-        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
+        static_assert(BlockMatrixA::GetLengths()[0] == BlockMatrixB::GetLengths()[0],
                       "wrong! K dimension not consistent\n");

-        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
-        constexpr index_t N = BlockMatrixB::NCol();
+        constexpr index_t M = BlockMatrixA::GetLengths()[1]; // A is transposed
+        constexpr index_t N = BlockMatrixB::GetLengths()[1];

         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
@@ -59,14 +67,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

-        mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row);
-        mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col);
+        mMyThreadOffsetA = BlockMatrixA::CalculateOffset({0, c_thread_mtx_index.row});
+        mMyThreadOffsetB = BlockMatrixB::CalculateOffset({0, c_thread_mtx_index.col});
     }

     __device__ static constexpr auto GetThreadMatrixCLengths()
     {
-        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
-        constexpr index_t N = BlockMatrixB::NCol();
+        constexpr index_t M = BlockMatrixA::GetLengths()[1]; // A is transposed
+        constexpr index_t N = BlockMatrixB::GetLengths()[1];

         constexpr index_t MRepeat =
             M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
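
MRepeat here simply counts how many level-1 cluster tiles fit along M in the block tile; NRepeat is computed the same way for N. A small worked example with made-up parameter values (not taken from this commit):

```cpp
#include <cassert>

int main()
{
    // Hypothetical configuration, chosen only to make the arithmetic visible.
    constexpr int MPerThreadSubC       = 4;
    constexpr int MLevel0ThreadCluster = 4;
    constexpr int MLevel1ThreadCluster = 4;
    constexpr int M                    = 128; // M length of the block tile of A (A is transposed)

    constexpr int MPerLevel1Cluster =
        MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; // 64
    static_assert(M % MPerLevel1Cluster == 0, "block tile must hold whole cluster tiles");

    constexpr int MRepeat = M / MPerLevel1Cluster; // 2: each thread owns 2 M-sub-tiles of 4 rows
    assert(MRepeat * MPerThreadSubC == 8);         // the per-thread C tile covers 8 rows of M
    return 0;
}
```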
@@ -125,8 +133,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t K = a_block_mtx.NRow();

-        constexpr index_t MPerThread = c_thread_mtx.NRow();
-        constexpr index_t NPerThread = c_thread_mtx.NCol();
+        constexpr index_t MPerThread = c_thread_mtx.GetLengths()[0];
+        constexpr index_t NPerThread = c_thread_mtx.GetLengths()[1];

         constexpr index_t MPerLevel1Cluster =
             MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
@@ -138,25 +146,36 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A, B for GEMM
         constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, MPerThread>{});

         constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, NPerThread>{});

         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
-                                                                 decltype(a_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 MPerThreadSubC,
-                                                                 ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
-                                                                 decltype(b_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 NPerThreadSubC,
-                                                                 ThreadGemmBDataPerRead_N>{};
+        constexpr auto a_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixA,
+                                                  decltype(a_thread_mtx),
+                                                  Sequence<KPerThreadLoop, MPerThreadSubC>,
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmADataPerRead_M,
+                                                  ThreadGemmADataPerRead_M,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
+
+        constexpr auto b_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixB,
+                                                  decltype(b_thread_mtx),
+                                                  Sequence<KPerThreadLoop, NPerThreadSubC>,
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});

         constexpr auto threadwise_gemm =
             ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
@@ -171,9 +190,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
                 a_thread_copy.Run(
-                    p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
+                    p_a_block +
+                        a_block_mtx.CalculateOffset({k_begin, m_repeat * MPerLevel1Cluster}) +
                         mMyThreadOffsetA,
-                    p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
+                    p_a_thread + a_thread_mtx.CalculateOffset({0, m_repeat * MPerThreadSubC}));
             }

 #pragma unroll
@@ -181,9 +201,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
             for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                 b_thread_copy.Run(
-                    p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
+                    p_b_block +
+                        b_block_mtx.CalculateOffset({k_begin, n_repeat * NPerLevel1Cluster}) +
                         mMyThreadOffsetB,
-                    p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
+                    p_b_thread + b_thread_mtx.CalculateOffset({0, n_repeat * NPerThreadSubC}));
             }

             // C += A * B
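
The two repeat loops above stage each thread's A and B sub-tiles from LDS into registers: in LDS the m_repeat-th (or n_repeat-th) sub-tile sits a further MPerLevel1Cluster (NPerLevel1Cluster) columns along, plus the thread's own offset, while in the packed register tile it sits at m_repeat * MPerThreadSubC. A tiny standalone sketch of that address arithmetic with hypothetical numbers, not code from the kernel:

```cpp
#include <cassert>

int main()
{
    // Hypothetical sizes, for illustration only.
    constexpr int MPerThreadSubC    = 4;   // columns each thread reads per repeat
    constexpr int MPerLevel1Cluster = 64;  // column span covered by one level-1 cluster
    constexpr int thread_col        = 12;  // this thread's column inside the cluster

    for(int m_repeat = 0; m_repeat < 2; ++m_repeat)
    {
        int src_col = m_repeat * MPerLevel1Cluster + thread_col; // column in the LDS block tile
        int dst_col = m_repeat * MPerThreadSubC;                 // column in the packed register tile
        assert(src_col == (m_repeat == 0 ? 12 : 76));
        assert(dst_col == (m_repeat == 0 ? 0 : 4));
    }
    return 0;
}
```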
@@ -217,34 +238,47 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         // thread A, B
         constexpr auto a_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, MPerThread>{});

         constexpr auto b_thread_mtx =
-            make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+            make_native_tensor_descriptor_packed(Sequence<KPerThreadLoop, NPerThread>{});

         // thread A-sub, B-sub
-        constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
-        constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
-            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
+        constexpr auto a_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<KPerThreadLoop, MPerThreadSubC>{}, Sequence<MPerThread, 1>{});
+
+        constexpr auto b_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<KPerThreadLoop, NPerThreadSubC>{}, Sequence<NPerThread, 1>{});

         // thread C-sub
-        constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
-            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
+        constexpr auto c_thread_sub_mtx = make_native_tensor_descriptor(
+            Sequence<MPerThreadSubC, NPerThreadSubC>{}, Sequence<NPerThread, 1>{});

         FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
         FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
-                                                                 decltype(a_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 MPerThreadSubC,
-                                                                 ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
-                                                                 decltype(b_thread_mtx),
-                                                                 KPerThreadLoop,
-                                                                 NPerThreadSubC,
-                                                                 ThreadGemmBDataPerRead_N>{};
+        constexpr auto a_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixA,
+                                                  decltype(a_thread_sub_mtx),
+                                                  decltype(a_thread_sub_mtx.GetLengths()),
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmADataPerRead_M,
+                                                  ThreadGemmADataPerRead_M,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});
+
+        constexpr auto b_thread_copy =
+            ThreadwiseGenericTensorSliceCopy_v4r2<BlockMatrixB,
+                                                  decltype(b_thread_mtx),
+                                                  decltype(b_thread_sub_mtx.GetLengths()),
+                                                  Sequence<0, 1>,
+                                                  1,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  ThreadGemmBDataPerRead_N,
+                                                  AddressSpace::Lds,
+                                                  AddressSpace::Vgpr,
+                                                  InMemoryDataOperation::Set>({0, 0}, {0, 0});

         constexpr auto threadwise_gemm =
             ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
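
The *_sub_mtx descriptors above are windows into the larger packed per-thread tiles: the lengths shrink to one sub-tile while the row stride stays at the full tile width (MPerThread or NPerThread). A standalone sketch of that idea; the Strided2D type below is an illustrative stand-in for the library's descriptor, not its API:

```cpp
#include <array>
#include <cassert>

// A 2D view described by lengths and strides, in the spirit of
// make_native_tensor_descriptor(lengths, strides).
struct Strided2D
{
    std::array<int, 2> lengths;
    std::array<int, 2> strides;

    int CalculateOffset(std::array<int, 2> idx) const
    {
        return idx[0] * strides[0] + idx[1] * strides[1];
    }
};

int main()
{
    constexpr int NPerThread = 8, NPerThreadSubC = 4, MPerThreadSubC = 4;

    // c_thread_sub_mtx: an MPerThreadSubC x NPerThreadSubC window of the packed
    // MPerThread x NPerThread C tile, so its row stride is the full row width NPerThread.
    Strided2D c_sub{{MPerThreadSubC, NPerThreadSubC}, {NPerThread, 1}};

    // element (1, 2) of the sub-window sits at 1 * 8 + 2 inside the full C tile
    assert(c_sub.CalculateOffset({1, 2}) == 10);
    return 0;
}
```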
@@ -261,77 +295,77 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         b_thread_copy.Run(p_b_block_off, p_b_thread);

         // read B_sub_1
-        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
-                          p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
+        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({0, NPerLevel1Cluster}),
+                          p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}));

         // read A_sub_1
-        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
-                          p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
+        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({0, MPerLevel1Cluster}),
+                          p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}));

         // C_sub_00 += transpose(A_sub_0) * B_sub_0
         threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

         // C_sub_01 += transpose(A_sub_0) * B_sub_1
         threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                            p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
+                            p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                            p_c_thread + ThreadMatrixC::CalculateOffset({0, NPerThreadSubC}));

 #pragma unroll
         // loop over rest of k
         for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
         {
             // read A_sub_0
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({k, 0}), p_a_thread);

             // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
                                 p_b_thread,
-                                p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
+                                p_c_thread + ThreadMatrixC::CalculateOffset({MPerThreadSubC, 0}));

             // read B_sub_0
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({k, 0}), p_b_thread);

             // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                                p_c_thread +
-                                    ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
+            threadwise_gemm.Run(
+                p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                p_c_thread + ThreadMatrixC::CalculateOffset({MPerThreadSubC, NPerThreadSubC}));

             // read B_sub_1
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
-                              p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
+            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset({k, NPerLevel1Cluster}),
+                              p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}));

             // read A_sub_1
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
-                              p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
+            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset({k, MPerLevel1Cluster}),
+                              p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}));

             // C_sub_00 += transpose(A_sub_0) * B_sub_0
             threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

             // C_sub_01 += transpose(A_sub_0) * B_sub_1
             threadwise_gemm.Run(p_a_thread,
-                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
-                                p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
+                                p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
+                                p_c_thread + ThreadMatrixC::CalculateOffset({0, NPerThreadSubC}));
         }

         // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
                             p_b_thread,
-                            p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
+                            p_c_thread + ThreadMatrixC::CalculateOffset({MPerThreadSubC, 0}));

         // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
-                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
+        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset({0, MPerThreadSubC}),
+                            p_b_thread + b_thread_mtx.CalculateOffset({0, NPerThreadSubC}),
                             p_c_thread +
-                                ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
+                                ThreadMatrixC::CalculateOffset({MPerThreadSubC, NPerThreadSubC}));
     }

     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
-        constexpr index_t MPerThread = ThreadMatrixC::NRow();
-        constexpr index_t NPerThread = ThreadMatrixC::NCol();
+        constexpr index_t MPerThread = ThreadMatrixC::GetLengths()[0];
+        constexpr index_t NPerThread = ThreadMatrixC::GetLengths()[1];

         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
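
The pipelined loop above juggles a 2x2 arrangement of C sub-blocks, whose base offsets in the packed per-thread C tile are the four combinations of {0, MPerThreadSubC} and {0, NPerThreadSubC}. A quick standalone check with hypothetical tile sizes (illustrative only):

```cpp
#include <cassert>

int main()
{
    // Hypothetical per-thread C tile: (2 * MPerThreadSubC) x (2 * NPerThreadSubC), packed row-major.
    constexpr int MPerThreadSubC = 4, NPerThreadSubC = 4;
    constexpr int NPerThread = 2 * NPerThreadSubC; // row stride of the packed C tile

    auto offset = [](int row, int col) { return row * NPerThread + col; };

    assert(offset(0, 0) == 0);                            // C_sub_00
    assert(offset(0, NPerThreadSubC) == 4);               // C_sub_01
    assert(offset(MPerThreadSubC, 0) == 32);              // C_sub_10
    assert(offset(MPerThreadSubC, NPerThreadSubC) == 36); // C_sub_11
    return 0;
}
```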
...
@@ -4,9 +4,9 @@
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
-#include "ConstantMatrixDescriptor.hpp"
 #include "blockwise_generic_tensor_slice_copy.hpp"
 #include "threadwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_op.hpp"
 #include "blockwise_gemm.hpp"

 namespace ck {
@@ -177,28 +177,24 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         //   b_mtx[KPerBlocl, NPerBlock] is in LDS
         //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
-        constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
-        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
-
         // sanity check
         static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
                           NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
                       "wrong!");

         constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
         constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
-        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
+        constexpr auto c_m0m1_n0n1_thread_desc = make_native_tensor_descriptor_packed(
+            Sequence<GemmMRepeat * MPerThread, GemmNRepeat * NPerThread>{});

         const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
             BlockSize,
-            decltype(a_k_m_block_mtx_desc),
-            decltype(b_k_n_block_mtx_desc),
-            decltype(c_m0m1_n0n1_thread_mtx_desc),
+            decltype(a_k_m_block_desc),
+            decltype(b_k_n_block_desc),
+            decltype(c_m0m1_n0n1_thread_desc),
             MPerThread,
             NPerThread,
             MLevel0Cluster,
@@ -220,10 +216,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         Float* p_b_block_double = p_shared_block + 2 * a_block_space;

         // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpace()];

         // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+        threadwise_generic_tensor_set_zero(c_m0m1_n0n1_thread_desc, p_c_thread);

         // LDS double buffer: preload data into LDS
         {
...
@@ -2,59 +2,11 @@
 #define CK_THREADWISE_GEMM_HPP

 #include "common_header.hpp"
-#include "ConstantMatrixDescriptor.hpp"
+#include "tensor_descriptor.hpp"
 #include "math.hpp"

 namespace ck {

-template <typename Float, class Matrix>
-__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
-{
-    for(index_t i = 0; i < Matrix::NRow(); ++i)
-    {
-        for(index_t j = 0; j < Matrix::NCol(); ++j)
-        {
-            const index_t id = Matrix::CalculateOffset(i, j);
-
-            p_thread[id] = Float(0);
-        }
-    }
-}
-
-template <typename SrcMatrix,
-          typename DstMatrix,
-          index_t NSliceRow,
-          index_t NSliceCol,
-          index_t DataPerAccess>
-struct ThreadwiseMatrixSliceCopy
-{
-    __device__ constexpr ThreadwiseMatrixSliceCopy()
-    {
-        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
-                          DstMatrix::RowStride() % DataPerAccess == 0,
-                      "wrong! wrong alignment");
-        static_assert(NSliceCol % DataPerAccess == 0,
-                      "wrong! should be NSliceCol % DataPerAccess == 0");
-    }
-
-    template <typename Data>
-    __device__ static void Run(const Data* p_src, Data* p_dst)
-    {
-        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
-
-        for(index_t i = 0; i < NSliceRow; ++i)
-        {
-            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
-            {
-                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
-                const index_t dst_index = DstMatrix::CalculateOffset(i, j);
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
-            }
-        }
-    }
-};
-
 // C += transpose(A) * B
 //   Element of matrix can be vectorized data
 template <typename MatrixA, typename MatrixB, typename MatrixC>
@@ -62,17 +14,18 @@ struct ThreadwiseGemmTransANormalBNormalC
 {
     __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
     {
-        static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
-                          MatrixB::NCol() == MatrixC::NCol(),
+        static_assert(MatrixA::GetLengths()[0] == MatrixB::GetLengths()[0] &&
+                          MatrixA::GetLengths()[1] == MatrixC::GetLengths()[0] &&
+                          MatrixB::GetLengths()[1] == MatrixC::GetLengths()[1],
                       "wrong!");
     }

     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
+        constexpr index_t M = MatrixC::GetLengths()[0];
+        constexpr index_t N = MatrixC::GetLengths()[1];
+        constexpr index_t K = MatrixA::GetLengths()[0]; // A is transposed

         for(index_t k = 0; k < K; ++k)
         {
@@ -80,9 +33,9 @@ struct ThreadwiseGemmTransANormalBNormalC
             {
                 for(index_t n = 0; n < N; ++n)
                 {
-                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
-                    const index_t bindex = MatrixB::CalculateOffset(k, n);
-                    const index_t cindex = MatrixC::CalculateOffset(m, n);
+                    const index_t aindex = MatrixA::CalculateOffset({k, m}); // A is transposed
+                    const index_t bindex = MatrixB::CalculateOffset({k, n});
+                    const index_t cindex = MatrixC::CalculateOffset({m, n});

                     p_c[cindex] +=
                         inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
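
Run_source above is a plain triple loop computing C += transpose(A) * B with A stored K x M and B stored K x N. A self-contained host-side reference with ordinary arrays, handy for checking the indexing; it mirrors the loop structure but is not the kernel code:

```cpp
#include <cassert>
#include <vector>

// C (M x N) += A^T * B, with A stored as K x M and B as K x N, all row-major.
void gemm_tn(int M, int N, int K,
             const std::vector<float>& a, const std::vector<float>& b, std::vector<float>& c)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}

int main()
{
    // M = N = 1, K = 2: C(0,0) += A(0,0)*B(0,0) + A(1,0)*B(1,0) = 1*3 + 2*4
    std::vector<float> a = {1.0f, 2.0f}, b = {3.0f, 4.0f}, c = {0.0f};
    gemm_tn(1, 1, 2, a, b, c);
    assert(c[0] == 11.0f);
    return 0;
}
```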
@@ -95,9 +48,9 @@ struct ThreadwiseGemmTransANormalBNormalC
     template <typename FloatA, typename FloatB, typename FloatC>
     __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
     {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
+        constexpr index_t M = MatrixC::GetLengths()[0];
+        constexpr index_t N = MatrixC::GetLengths()[1];
+        constexpr index_t K = MatrixA::GetLengths()[0]; // A is transposed

         static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
@@ -108,26 +61,26 @@ struct ThreadwiseGemmTransANormalBNormalC
             const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

             static_if<N == 2>{}([&](auto) {
-                const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
+                const index_t bindex_0 = MatrixB::CalculateOffset({k, 0});
+                const index_t bindex_1 = MatrixB::CalculateOffset({k, 1});

-                const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
+                const index_t cindex_0 = MatrixC::CalculateOffset({m, 0});
+                const index_t cindex_1 = MatrixC::CalculateOffset({m, 1});

                 amd_assembly_outer_product_1x2(
                     p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
             });

             static_if<N == 4>{}([&](auto) {
-                const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
-                const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
-                const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);
+                const index_t bindex_0 = MatrixB::CalculateOffset({k, 0});
+                const index_t bindex_1 = MatrixB::CalculateOffset({k, 1});
+                const index_t bindex_2 = MatrixB::CalculateOffset({k, 2});
+                const index_t bindex_3 = MatrixB::CalculateOffset({k, 3});

-                const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
-                const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
-                const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
+                const index_t cindex_0 = MatrixC::CalculateOffset({m, 0});
+                const index_t cindex_1 = MatrixC::CalculateOffset({m, 1});
+                const index_t cindex_2 = MatrixC::CalculateOffset({m, 2});
+                const index_t cindex_3 = MatrixC::CalculateOffset({m, 3});

                 amd_assembly_outer_product_1x4(p_a[aindex],
                                                p_b[bindex_0],
...
@@ -2,18 +2,14 @@
 #define CK_THREADWISE_GENERIC_TENSOR_OP_HPP

 #include "common_header.hpp"
-#include "ConstantTensorDescriptor_deprecated.hpp"
-#include "ConstantMergedTensorDescriptor_deprecated.hpp"
+#include "tensor_descriptor.hpp"

 namespace ck {

-template <class Float, class TDesc>
-__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
+template <class Float, class TensorDesc>
+__device__ void threadwise_generic_tensor_set_zero(TensorDesc, Float* __restrict__ p)
 {
-    static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
-        constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
-        p[offset] = static_cast<Float>(0);
-    });
+    ford<decltype(TensorDesc::GetLengths())>{}(
+        [&](auto idx) { p[TensorDesc::CalculateOffset(idx)] = static_cast<Float>(0); });
 }

 } // namespace ck
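
The rewritten threadwise_generic_tensor_set_zero visits every multi-index in the descriptor's lengths via ford and zeroes the element at CalculateOffset(idx). A host-side analogue of that behaviour for a strided 2D tile, written as a sketch rather than against the ck API:

```cpp
#include <array>
#include <cassert>
#include <vector>

// Zero a lengths[0] x lengths[1] tile inside a buffer with the given strides,
// visiting every multi-index the way ford<Lengths> does.
void set_zero_2d(std::array<int, 2> lengths, std::array<int, 2> strides, std::vector<float>& buf)
{
    for(int i0 = 0; i0 < lengths[0]; ++i0)
        for(int i1 = 0; i1 < lengths[1]; ++i1)
            buf[i0 * strides[0] + i1 * strides[1]] = 0.0f; // CalculateOffset({i0, i1})
}

int main()
{
    // a 2 x 3 tile embedded in rows of width 4 (strides {4, 1})
    std::vector<float> buf(8, 1.0f);
    set_zero_2d({2, 3}, {4, 1}, buf);
    assert(buf[0] == 0.0f && buf[3] == 1.0f && buf[4] == 0.0f && buf[7] == 1.0f);
    return 0;
}
```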
...
#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_DEPRECATED_HPP
#define CK_THREADWISE_GENERIC_TENSOR_OP_DEPRECATED_HPP
#include "common_header.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
namespace ck {
template <class Float, class TDesc>
__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p)
{
static_ford<decltype(TDesc::GetLengths())>{}([&](auto multi_id) {
constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id);
p[offset] = static_cast<Float>(0);
});
}
} // namespace ck
#endif
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

-#if 0
+#if 1
     // BlockSize = 256, GemmKPerBlock = 8
     constexpr index_t BlockSize = 256;
...
@@ -18,18 +18,18 @@
 //#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
 //#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp"
-#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
+//#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp"
+//#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp"
-#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp"
+//#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp"
 #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 1
+#if 0
     // 1x1
     constexpr index_t N = 64;
     constexpr index_t C = 64;
@@ -59,7 +59,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 3>;
     using RightPads = Sequence<0, 3>;

-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
...