Commit dea22d0e authored by Chao Liu's avatar Chao Liu
Browse files

added IsKnownAtCompileTime() in multi-index transform and tensor descriptor

parent 55599afd
......@@ -74,6 +74,11 @@ struct DynamicPassThrough
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -160,6 +165,13 @@ struct DynamicPad
(idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LeftPad>::value &&
is_known_at_compile_time<RightPad>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -243,6 +255,12 @@ struct DynamicLeftPad
return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LeftPad>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -328,6 +346,13 @@ struct DynamicRightPad
return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<LowLength>::value &&
is_known_at_compile_time<RightPad>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -424,6 +449,12 @@ struct DynamicEmbed
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<Coefficients>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -930,6 +961,13 @@ struct DynamicMerge
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsScan>::value &&
is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
......@@ -1033,6 +1071,12 @@ struct DynamicUnMerge
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<UpLengths>::value &&
is_known_at_compile_time<UpLengthsScan>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......@@ -1097,6 +1141,11 @@ struct DynamicFreeze
return true;
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowerIndex>::value;
}
__host__ __device__ void Print() const
{
printf("DynamicFreeze");
......
......@@ -201,6 +201,19 @@ struct DynamicTensorDescriptor
return VisibleDimensionIds{};
}
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
bool is_known = true;
static_for<0, Transforms::Size(), 1>{}([&](auto i) {
is_known &=
remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
});
return is_known && is_known_at_compile_time<ElementSize>::value &&
is_known_at_compile_time<ElementSpaceSize>::value;
}
__host__ __device__ void Print() const
{
printf("{");
......
......@@ -330,360 +330,5 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
}
};
// blockwise GEMM: C += transpose(A) * B
// A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(
make_tuple(0, m_repeat * MPerThreadSubC)));
}
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(p_b_block +
b_block_mtx.CalculateOffset(
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{
// read A_sub_0
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
} // namespace ck
#endif
#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "threadwise_gemm_v2.hpp"
namespace ck {
// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
// MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
typename BlockMatrixA,
typename BlockMatrixB,
typename ThreadMatrixC,
index_t MPerThreadSubC,
index_t NPerThreadSubC,
index_t KPerThreadLoop,
index_t MLevel0ThreadCluster,
index_t NLevel0ThreadCluster,
index_t MLevel1ThreadCluster,
index_t NLevel1ThreadCluster,
index_t ThreadGemmADataPerRead_M,
index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
struct MatrixIndex
{
index_t row;
index_t col;
};
index_t mMyThreadOffsetA;
index_t mMyThreadOffsetB;
__device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
{
static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
BlockMatrixB::IsKnownAtCompileTime() &&
ThreadMatrixC::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
MLevel1ThreadCluster * NLevel1ThreadCluster;
static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
"wrong! K dimension not consistent\n");
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
"wrong! Cannot evenly divide work among\n");
static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
"wrong! ThreadMatrixC lengths is wrong");
auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
}
__device__ static constexpr auto GetThreadMatrixCLengths()
{
constexpr auto I1 = Number<1>{};
constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
constexpr index_t N = BlockMatrixB{}.GetLength(I1);
constexpr index_t MRepeat =
M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
constexpr index_t NRepeat =
N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
}
__device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
{
constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
index_t level1_id = thread_id / ThreadPerLevel0Cluster;
index_t level1_m_id = level1_id / NLevel1ThreadCluster;
index_t level1_n_id = level1_id % NLevel1ThreadCluster;
index_t level0_id = thread_id % ThreadPerLevel0Cluster;
index_t level0_m_id = level0_id / NLevel0ThreadCluster;
index_t level0_n_id = level0_id % NLevel0ThreadCluster;
constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
// thread A, B for GEMM
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
Number<KPerThreadLoop>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
decltype(b_thread_mtx),
decltype(c_thread_mtx)>{};
#pragma unroll
// loop over k
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
{
#pragma unroll
// read A
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
{
a_thread_copy.Run(p_a_block +
a_block_mtx.CalculateOffset(
make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
mMyThreadOffsetA,
p_a_thread + a_thread_mtx.CalculateOffset(
make_tuple(0, m_repeat * MPerThreadSubC)));
}
#pragma unroll
// read B
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
{
b_thread_copy.Run(p_b_block +
b_block_mtx.CalculateOffset(
make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
mMyThreadOffsetB,
p_b_thread + b_thread_mtx.CalculateOffset(
make_tuple(0, n_repeat * NPerThreadSubC)));
}
// C += A * B
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
}
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr auto K = a_block_mtx.GetLength(I0);
constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
constexpr index_t MPerLevel1Cluster =
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
constexpr index_t NPerLevel1Cluster =
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
static_assert(MRepeat == 2 && NRepeat == 2,
"wrong! inline asm cannot deal with this GEMM config yet");
// thread A, B
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
// thread A-sub, B-sub
constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
make_tuple(Number<MPerThread>{}, Number<1>{}));
constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
make_tuple(Number<NPerThread>{}, Number<1>{}));
FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
decltype(a_thread_mtx),
KPerThreadLoop,
MPerThreadSubC,
ThreadGemmADataPerRead_M>{};
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
decltype(b_thread_mtx),
KPerThreadLoop,
NPerThreadSubC,
ThreadGemmBDataPerRead_N>{};
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
decltype(b_thread_sub_mtx),
decltype(c_thread_sub_mtx)>{};
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
// read A_sub_0
a_thread_copy.Run(p_a_block_off, p_a_thread);
// read B_sub_0
b_thread_copy.Run(p_b_block_off, p_b_thread);
// read B_sub_1
b_thread_copy.Run(p_b_block_off +
b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(p_a_block_off +
a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
#pragma unroll
// loop over rest of k
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
{
// read A_sub_0
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_a_thread);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// read B_sub_0
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)),
p_b_thread);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread +
c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
// read B_sub_1
b_thread_copy.Run(
p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
// read A_sub_1
a_thread_copy.Run(
p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(
p_a_thread,
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
}
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread,
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(
p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
}
template <typename FloatA, typename FloatB, typename FloatC>
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
{
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
if constexpr(MRepeat == 2 && NRepeat == 2)
{
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
}
else
{
Run_naive(p_a_block, p_b_block, p_c_thread);
}
#else
Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
}
};
} // namespace ck
#endif
......@@ -2,13 +2,12 @@
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP
#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_gemm.hpp"
#include "blockwise_gemm_v2.hpp"
namespace ck {
......@@ -397,6 +396,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
constexpr auto tmp = make_unmerge_transform(make_tuple(
Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
ThreadwiseDynamicTensorSliceTransfer_v1r3<
AccFloat,
Float,
......
......@@ -44,7 +44,8 @@ template <typename SrcData,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun>
bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
static constexpr index_t nDim = SliceLengths::Size();
......@@ -59,6 +60,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
: dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
{
static_assert(SrcDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
......@@ -342,7 +345,8 @@ template <typename SrcData,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun>
bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v2
{
static constexpr index_t nDim = SliceLengths::Size();
......@@ -357,6 +361,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
const Index& src_slice_origin_idx)
: src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
}
__device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
......@@ -1233,9 +1239,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
private:
static constexpr auto buffer_desc_ =
make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(SliceLengths{}));
make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));
static constexpr index_t buffer_size_ = buffer_desc_.GetElementSpaceSize();
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
......
......@@ -161,160 +161,5 @@ struct ThreadwiseGemmTransANormalBNormalC
}
};
template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero_v2(Matrix, Float* __restrict__ p_thread)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = Matrix{}.GetLength(I0);
constexpr auto N = Matrix{}.GetLength(I1);
static_for<0, M, 1>{}([&](auto i) {
static_for<0, N, 1>{}([&](auto j) {
constexpr auto offset = Matrix{}.CalculateOffset(make_tuple(i, j));
p_thread[offset] = Float(0);
});
});
}
template <typename SrcMatrix,
typename DstMatrix,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcMatrix{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstMatrix{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
// C += transpose(A) * B
// Element of matrix can be vectorized data
template <typename MatrixA, typename MatrixB, typename MatrixC>
struct ThreadwiseGemm_km_kn_mn_v1
{
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr index_t M = MatrixC{}[I0];
constexpr index_t N = MatrixC{}[I1];
constexpr index_t K = MatrixA{}[I0];
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
const index_t a_offset =
MatrixA{}.CalculateOffset(make_tuple(k, m)); // A is transposed
const index_t b_offset = MatrixB{}.CalculateOffset(make_tuple(k, n));
const index_t c_offset = MatrixC{}.CalculateOffset(make_tuple(m, n));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr index_t M = MatrixC{}.GetLength(I0);
constexpr index_t N = MatrixC{}.GetLength(I1);
constexpr index_t K = MatrixA{}.GetLength(I0);
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m));
if constexpr(N == 2)
{
constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
amd_assembly_outer_product_1x2(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_c[c_offset_0],
p_c[c_offset_1]);
}
else if constexpr(N == 4)
{
constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
constexpr auto b_offset_2 = MatrixB{}.CalculateOffset(make_tuple(k, I2));
constexpr auto b_offset_3 = MatrixB{}.CalculateOffset(make_tuple(k, I3));
constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
constexpr auto c_offset_2 = MatrixC{}.CalculateOffset(make_tuple(m, I2));
constexpr auto c_offset_3 = MatrixC{}.CalculateOffset(make_tuple(m, I3));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
});
});
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
if constexpr(has_amd_asm)
{
Run_amd_asm(p_a, p_b, p_c);
}
else
{
Run_source(p_a, p_b, p_c);
}
#else
Run_source(p_a, p_b, p_c);
#endif
}
};
} // namespace ck
#endif
#ifndef CK_THREADWISE_GEMM_V2_HPP
#define CK_THREADWISE_GEMM_V2_HPP
#include "common_header.hpp"
#include "math.hpp"
namespace ck {
template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
{
static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto desc = Desc{};
constexpr auto M = desc.GetLength(I0);
constexpr auto N = desc.GetLength(I1);
static_for<0, M, 1>{}([&](auto i) {
static_for<0, N, 1>{}([&](auto j) {
constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));
p_thread[offset] = Float(0);
});
});
}
template <typename SrcDesc,
typename DstDesc,
index_t NSliceRow,
index_t NSliceCol,
index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
template <typename Data>
__device__ static void Run(const Data* p_src, Data* p_dst)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
static_for<0, NSliceRow, 1>{}([&](auto i) {
static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
*reinterpret_cast<const vector_t*>(&p_src[src_offset]);
});
});
}
};
// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
typename BDesc,
typename CDesc,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
{
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto M = CDesc{}[I0];
constexpr auto N = CDesc{}[I1];
constexpr auto K = ADesc{}[I0];
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
static_for<0, N, 1>{}([&](auto n) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
p_c[c_offset] +=
inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
});
});
});
}
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto M = CDesc{}.GetLength(I0);
constexpr auto N = CDesc{}.GetLength(I1);
constexpr auto K = ADesc{}.GetLength(I0);
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
static_for<0, K, 1>{}([&](auto k) {
static_for<0, M, 1>{}([&](auto m) {
constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
if constexpr(N == 2)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
amd_assembly_outer_product_1x2(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_c[c_offset_0],
p_c[c_offset_1]);
}
else if constexpr(N == 4)
{
constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));
constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));
amd_assembly_outer_product_1x4(p_a[a_offset],
p_b[b_offset_0],
p_b[b_offset_1],
p_b[b_offset_2],
p_b[b_offset_3],
p_c[c_offset_0],
p_c[c_offset_1],
p_c[c_offset_2],
p_c[c_offset_3]);
}
});
});
}
#endif
template <typename FloatA, typename FloatB, typename FloatC>
__device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
{
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
(is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
(is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
if constexpr(has_amd_asm)
{
Run_amd_asm(p_a, p_b, p_c);
}
else
{
Run_source(p_a, p_b, p_c);
}
#else
Run_source(p_a, p_b, p_c);
#endif
}
};
} // namespace ck
#endif
......@@ -279,5 +279,18 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}
template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{
using Seq = Sequence<Is...>;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = Seq::At(i);
return Number<tmp>{};
},
Seq::Size());
}
} // namespace ck
#endif
......@@ -6,6 +6,24 @@
namespace ck {
template <typename... Ts>
struct is_known_at_compile_time<Tuple<Ts...>>
{
__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return container_reduce(
Tuple<Ts...>{},
[](auto x, bool r) {
return is_known_at_compile_time<
remove_cv_t<remove_reference_t<decltype(x)>>>::value &
r;
},
true);
}
static constexpr bool value = IsKnownAtCompileTime();
};
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
......
......@@ -27,5 +27,20 @@ constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
return static_cast<typename std::remove_reference<T>::type&&>(t);
}
template <typename T>
struct is_known_at_compile_time;
template <>
struct is_known_at_compile_time<index_t>
{
static constexpr bool value = false;
};
template <typename T, T X>
struct is_known_at_compile_time<integral_constant<T, X>>
{
static constexpr bool value = true;
};
} // namespace ck
#endif
......@@ -3,19 +3,6 @@
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template <typename T>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(const T& x)
{
using namespace ck;
return generate_tuple(
[&](auto i) {
constexpr index_t tmp = T::At(i);
return Number<tmp>{};
},
T::Size());
}
template <class T,
class InDesc,
class WeiDesc,
......@@ -269,7 +256,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr auto conv_driver =
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#elif 0
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment