"...composable_kernel.git" did not exist on "970fa3e92ec4e67cfbfe1b0428e84870663ab8cd"
Commit 90b3ccac authored by Chao Liu

recovering code

parent a7c587ee
...@@ -95,28 +95,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
#if 0
__device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
index_t n_in_c)
{
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
index_t m_repeat = m_in_c / MPerThreadSubC;
index_t n_repeat = n_in_c / NPerThreadSubC;
index_t m_in_sub_c = m_in_c % MPerThreadSubC;
index_t n_in_sub_c = n_in_c % NPerThreadSubC;
return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
n_repeat * NPerLevel1Cluster + n_in_sub_c};
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
...@@ -352,5 +330,366 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
};
// blockwise GEMM: C += transpose(A) * B
// A and B are visible to the whole block, C is distributed among the threads
// If the following numbers are powers of 2, index calculation cost is greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
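// Worked example (a hypothetical configuration, consistent with the debug
// static_asserts further down; the 4 x 4 split between level-0 and level-1
// clusters is an assumption, not required by this code):
//   MPerThreadSubC = NPerThreadSubC = 4,
//   MLevel0ThreadCluster = NLevel0ThreadCluster = 4,
//   MLevel1ThreadCluster = NLevel1ThreadCluster = 4
//   => BlockSize = 4 * 4 * 4 * 4 = 256 threads,
//      MPerLevel1Cluster = NPerLevel1Cluster = 4 * 4 * 4 = 64
//   For a 128 x 128 block tile of C, MRepeat = NRepeat = 128 / 64 = 2, so each
//   thread accumulates an 8 x 8 ThreadMatrixC made of 2 x 2 sub-tiles of size 4 x 4.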
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
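// offsets of this thread's first A / B element inside the LDS blocks; A is stored
// K x M and B is stored K x N, so the column index used is this thread's starting
// row / column of C (set once in the constructor from GetBeginOfThreadMatrixC())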
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
#if 0
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
#else
constexpr auto tmp0 = GetThreadMatrixCLengths()[I0];
constexpr auto tmp1 = GetThreadMatrixCLengths()[I1];
static_assert(tmp0 == 8, "wrong!");
static_assert(tmp1 == 8, "wrong!");
static_assert(tmp0 == Number<8>{}, "wrong!");
static_assert(tmp1 == Number<8>{}, "wrong!");
static_assert(ThreadMatrixC{}.GetLength(I0) == tmp0 &&
ThreadMatrixC{}.GetLength(I1) == tmp1,
"wrong! ThreadMatrixC lengths is wrong");
#endif
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
static_assert(M == 128, "wrong!");
static_assert(MPerThreadSubC == 4, "wrong!");
static_assert(MRepeat == 2, "wrong!");
static_assert(NRepeat == 2, "wrong!");
static_assert(NPerThreadSubC == 4, "wrong!");
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t K = a_block_mtx[I0];
constexpr index_t MPerThread = c_thread_mtx[I0];
constexpr index_t NPerThread = c_thread_mtx[I1];
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(
make_tuple(0, m_repeat * MPerThreadSubC)));
}
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(p_b_block +
b_block_mtx.CalculateOffset(
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
// thread C-sub
constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm =
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{
// read A_sub_0
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread,
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
// read B_sub_1
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
// read A_sub_1
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread,
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
p_c_thread +
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
} // namespace ck
#endif
...@@ -130,12 +130,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_multi_index(KPerBlock, MPerBlock), max_lds_align);
make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), Number<max_lds_align>{});
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
make_multi_index(KPerBlock, NPerBlock), max_lds_align);
make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), Number<max_lds_align>{});
// A matrix blockwise copy
auto a_blockwise_copy =
...@@ -201,6 +201,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
#if 0
constexpr index_t a_k_m_block_mtx_stride =
a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));
...@@ -212,6 +213,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
#endif
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
...@@ -223,23 +225,28 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
#if 0
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
#else
constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
#endif
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
    BlockSize,
    decltype(a_k_m_block_mtx_desc),
    decltype(b_k_n_block_mtx_desc),
    decltype(c_m0m1_n0n1_thread_mtx_desc),
    MPerThread,
    NPerThread,
    KPerThread,
    MLevel0Cluster,
    NLevel0Cluster,
    MLevel1Cluster,
    NLevel1Cluster,
    MPerThread,
    NPerThread>{};
const auto blockwise_gemm =
    BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
                                    decltype(a_k_m_block_desc),
                                    decltype(b_k_n_block_desc),
                                    decltype(c_m0m1_n0n1_thread_desc),
                                    MPerThread,
                                    NPerThread,
                                    KPerThread,
                                    MLevel0Cluster,
                                    NLevel0Cluster,
                                    MLevel1Cluster,
                                    NLevel1Cluster,
                                    MPerThread,
                                    NPerThread>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space_size =
...@@ -252,10 +259,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_desc, p_c_thread);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...@@ -422,7 +429,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
true,
true>(c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
...
...@@ -44,12 +44,12 @@ template <typename SrcData,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun,
bool DstResetCoordinateAfterRun>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));
...@@ -61,10 +61,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
}
#if 0
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3()
: ThreadwiseDynamicTensorSliceTransfer_v1r3(DstDesc{}, make_zero_multi_index<nDim>())
{
}
#endif
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
...@@ -297,7 +299,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
return forward_sweep;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// calculate dst data index after last iteration in Run(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
...@@ -328,7 +330,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
...@@ -344,6 +346,326 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
DstCoord dst_slice_origin_coord_;
}; // namespace ck
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume dst_slice_origin_idx is 0
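// A minimal usage sketch (hypothetical call site, not part of this commit): the
// caller constructs the transfer with the source descriptor and this thread's
// source origin, then alternates Run() (copy one slice into thread registers, with
// dst offsets computed at compile time from DstDesc) and MoveSrcSliceWindow() to
// advance the source window, e.g.
//
//   ThreadwiseDynamicTensorSliceTransfer_v2<...> copy(src_desc, src_origin);
//   copy.Run(src_desc, p_src, p_dst);
//   copy.MoveSrcSliceWindow(src_desc, slice_step);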
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun>
struct ThreadwiseDynamicTensorSliceTransfer_v2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
{
}
#if 0
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2()
: ThreadwiseDynamicTensorSliceTransfer_v1r3(SrcDesc{}, make_zero_multi_index<nDim>())
{
}
#endif
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
}
template <typename SrcIteratorHacks>
__device__ void Run(const SrcDesc& src_desc,
const SrcData* p_src,
DstData* p_dst,
const SrcIteratorHacks& src_iterator_hacks)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
// Comments: dst_desc is constexpr
constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// make forward iterators
const auto src_forward_iterators = generate_tuple(
[&](auto i) {
Index forward_step;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, forward_step, src_iterator_hacks[I0][i]);
},
Number<nDim>{});
// make backward iterators
const auto src_backward_iterators = generate_tuple(
[&](auto i) {
Index backward_step;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_dynamic_tensor_coordinate_iterator(
src_desc, backward_step, src_iterator_hacks[I1][i]);
},
Number<nDim>{});
// loop over tensor and copy
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
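// forward_sweep[i] == true means dimension i is currently traversed in ascending
// order; the traversal is a zig-zag order in which the sweep direction of the
// inner dimensions flips whenever an outer dimension advances, so between
// consecutive accesses the coordinate moves by a single +/- step along exactly
// one dimension and never has to be recomputed from scratch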
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// copy data
// hardcoding for buffer_load / ds_read
// TODO refactor transfer_data() to encapsulate this
static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");
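// Two source paths below: a global source is fetched with amd_buffer_load and the
// vector is zeroed afterwards if the coordinate is out of bounds; any other source
// (e.g. LDS) is read through a plain reinterpret_cast guarded by the same validity
// check.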
vector_type<SrcData, SrcScalarPerVector> src_vector;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
if constexpr(SrcAddressSpace == AddressSpace::Global)
{
src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(
p_src,
src_slice_origin_coord_.GetOffset(),
true,
src_desc.GetElementSpaceSize());
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};
}
else
{
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_vector.Vector() = is_valid ? *reinterpret_cast<const src_vector_t*>(
&p_src[src_slice_origin_coord_.GetOffset()])
: src_vector_t{0};
}
// this is hardcoded for dst that has compile-time tensor descriptor
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// assume dst_slice_origin_idx is 0
// TODO: support non-zero dst_slice_origin_idx
constexpr index_t dst_offset =
dst_desc.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
p_dst[Number<dst_offset>{}] = src_vector[i];
});
constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim;
}
();
// move
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_dynamic_tensor_coordinate(src_desc,
src_slice_origin_coord_,
src_forward_iterators[dim_access_order[i]]);
}
else
{
move_dynamic_tensor_coordinate(src_desc,
src_slice_origin_coord_,
src_backward_iterators[dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_iterator =
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
}
}
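// Convenience overload below: builds an all-zero "iterator hack" tuple (one zero
// sequence per dimension, for both the forward and the backward iterators) and
// forwards to the hack-taking Run() above.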
__device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
{
constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
constexpr auto src_iterator_hacks =
make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
Run(src_desc, p_src, p_dst, src_iterator_hacks);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto I0 = Number<0>{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
forward_sweep(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
static_for<0, i, 1>{}([&](auto j) {
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});
forward_sweep(i) = tmp % 2 == 0;
});
return forward_sweep;
}();
// calculate src data index after the last iteration in Run(), if it has not been
// reset by Run()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
return src_data_idx;
}();
// the reset step is the negative of that index, i.e. a single move back to the slice origin
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
return reset_src_data_step;
}();
return reset_src_data_step;
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step =
make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);
move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
}
private:
SrcCoord src_slice_origin_coord_;
}; // namespace ck
// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions
// 1. It does not keep reference to tensor descriptor
...@@ -398,11 +720,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
"wrong!");
}
#if 0
__device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3()
: ThreadwiseDynamicTensorSliceTransfer_v3(
SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
{
}
#endif
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
...@@ -512,7 +836,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
vector_type<SrcData, SrcScalarPerVector> src_vector;
using SrcVectorType = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
#if 1
src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(
...@@ -521,7 +845,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
src_desc, src_slice_origin_coord_);
src_vector.Vector() = is_valid ? src_vector.Vector() : SrcVectorType{0};
src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t buffer_offset =
...
...@@ -10,6 +10,7 @@ namespace ck {
template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
#if 0
for(index_t i = 0; i < Matrix::NRow(); ++i)
{
for(index_t j = 0; j < Matrix::NCol(); ++j)
...@@ -18,6 +19,21 @@ __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
p_thread[id] = Float(0);
}
}
#else
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = Matrix{}.GetLength(I0);
constexpr auto N = Matrix{}.GetLength(I1);
static_for<0, M, 1>{}([&](auto i) {
static_for<0, N, 1>{}([&](auto j) {
constexpr auto offset = Matrix{}.CalculateOffset(make_tuple(i, j));
p_thread[offset] = Float(0);
});
});
#endif
}
template <typename SrcMatrix,
...@@ -32,6 +48,7 @@ struct ThreadwiseMatrixSliceCopy
static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
DstMatrix::RowStride() % DataPerAccess == 0,
"wrong! wrong alignment");
static_assert(NSliceCol % DataPerAccess == 0,
"wrong! should be NSliceCol % DataPerAccess == 0");
}
...@@ -41,6 +58,7 @@ struct ThreadwiseMatrixSliceCopy
{
using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
#if 0
for(index_t i = 0; i < NSliceRow; ++i)
{
for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
...@@ -52,6 +70,17 @@ struct ThreadwiseMatrixSliceCopy
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
}
}
#else
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcMatrix{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstMatrix{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
#endif
}
};
...@@ -62,85 +91,95 @@ struct ThreadwiseGemmTransANormalBNormalC
{
__device__ constexpr ThreadwiseGemmTransANormalBNormalC()
{
#if 0
static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
MatrixB::NCol() == MatrixC::NCol(),
"wrong!");
#endif
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr index_t M = MatrixC::NRow();
constexpr index_t N = MatrixC::NCol();
constexpr index_t K = MatrixA::NRow(); // A is transposed

for(index_t k = 0; k < K; ++k)
{
    for(index_t m = 0; m < M; ++m)
    {
        for(index_t n = 0; n < N; ++n)
        {
            const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
            const index_t bindex = MatrixB::CalculateOffset(k, n);
            const index_t cindex = MatrixC::CalculateOffset(m, n);

            p_c[cindex] +=
                inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
        }
    }
}
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};

constexpr index_t M = MatrixC{}[I0];
constexpr index_t N = MatrixC{}[I1];
constexpr index_t K = MatrixA{}[I0];

static_for<0, K, 1>{}([&](auto k) {
    static_for<0, M, 1>{}([&](auto m) {
        static_for<0, N, 1>{}([&](auto n) {
            const index_t a_offset =
                MatrixA{}.CalculateOffset(make_tuple(k, m)); // A is transposed
            const index_t b_offset = MatrixB{}.CalculateOffset(make_tuple(k, n));
            const index_t c_offset = MatrixC{}.CalculateOffset(make_tuple(m, n));

            p_c[c_offset] +=
                inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
        });
    });
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr index_t M = MatrixC::NRow();
constexpr index_t N = MatrixC::NCol();
constexpr index_t K = MatrixA::NRow(); // A is transposed

static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

for(index_t k = 0; k < K; ++k)
{
    for(index_t m = 0; m < M; ++m)
    {
        const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed

        static_if<N == 2>{}([&](auto) {
            const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
            const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
            const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
            const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);

            amd_assembly_outer_product_1x2(
                p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
        });

        static_if<N == 4>{}([&](auto) {
            const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
            const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
            const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
            const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);

            const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
            const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
            const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
            const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);

            amd_assembly_outer_product_1x4(p_a[aindex],
                                           p_b[bindex_0],
                                           p_b[bindex_1],
                                           p_b[bindex_2],
                                           p_b[bindex_3],
                                           p_c[cindex_0],
                                           p_c[cindex_1],
                                           p_c[cindex_2],
                                           p_c[cindex_3]);
        });
    }
}
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr index_t M = MatrixC{}[I0];
constexpr index_t N = MatrixC{}[I1];
constexpr index_t K = MatrixA{}[I0];

static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

static_for<0, K, 1>{}([&](auto k) {
    static_for<0, M, 1>{}([&](auto m) {
        constexpr auto a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m));

        if constexpr(N == 2)
        {
            constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
            constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
            constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
            constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));

            amd_assembly_outer_product_1x2(p_a[a_offset],
                                           p_b[b_offset_0],
                                           p_b[b_offset_1],
                                           p_c[c_offset_0],
                                           p_c[c_offset_1]);
        }
        else if constexpr(N == 4)
        {
            constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
            constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
            constexpr auto b_offset_2 = MatrixB{}.CalculateOffset(make_tuple(k, I2));
            constexpr auto b_offset_3 = MatrixB{}.CalculateOffset(make_tuple(k, I3));

            constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
            constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
            constexpr auto c_offset_2 = MatrixC{}.CalculateOffset(make_tuple(m, I2));
            constexpr auto c_offset_3 = MatrixC{}.CalculateOffset(make_tuple(m, I3));

            amd_assembly_outer_product_1x4(p_a[a_offset],
                                           p_b[b_offset_0],
                                           p_b[b_offset_1],
                                           p_b[b_offset_2],
                                           p_b[b_offset_3],
                                           p_c[c_offset_0],
                                           p_c[c_offset_1],
                                           p_c[c_offset_2],
                                           p_c[c_offset_3]);
        }
    });
});
} }
#endif
...@@ -153,8 +192,14 @@ struct ThreadwiseGemmTransANormalBNormalC
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
    .Else([&](auto) { Run_source(p_a, p_b, p_c); });
if constexpr(has_amd_asm)
{
    Run_amd_asm(p_a, p_b, p_c);
}
else
{
    Run_source(p_a, p_b, p_c);
}
#else
Run_source(p_a, p_b, p_c);
#endif
...
...@@ -114,8 +114,8 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
}
// greatest common divisor, aka highest common factor
template <typename X, typename Y>
__host__ __device__ constexpr auto gcd(X x, Y y)
template <typename T>
__host__ __device__ constexpr T gcd(T x, T y)
{
if(x == y || x == 0)
{
...@@ -135,6 +135,14 @@ __host__ __device__ constexpr auto gcd(X x, Y y)
}
}
template<index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
constexpr auto r = gcd(X, Y);
return Number<r>{};
}
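// e.g. gcd(Number<8>{}, Number<12>{}) yields Number<4>{}; this overload keeps the
// result as a compile-time Number instead of decaying it to a runtime value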
template <typename X, typename... Ys>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
...