"examples/vscode:/vscode.git/clone" did not exist on "f06e4e5579b2c43a35a1678dea9b9ec3d66d365a"
Commit 46a0aec1 authored by Chao Liu's avatar Chao Liu
Browse files

trying gemm asm

parent 2603bb0f
...@@ -580,13 +580,16 @@ int main(int argc, char* argv[]) ...@@ -580,13 +580,16 @@ int main(int argc, char* argv[])
#if 0 #if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 1 #elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
auto gen_wei = [](auto... is) { auto gen_wei = [](auto... is) {
......
...@@ -289,9 +289,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -289,9 +289,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
const FloatB* __restrict__ p_b_block, const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const FloatC* __restrict__ p_c_thread) const
{ {
constexpr auto True = integral_constant<bool, true>{};
constexpr auto False = integral_constant<bool, false>{};
constexpr auto a_block_mtx = BlockMatrixA{}; constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{}; constexpr auto c_thread_mtx = ThreadMatrixC{};
...@@ -371,6 +368,102 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -371,6 +368,102 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
} }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
const FloatB* __restrict__ p_b_block,
FloatC* __restrict__ p_c_thread) const
{
constexpr auto a_block_mtx = BlockMatrixA{};
constexpr auto b_block_mtx = BlockMatrixB{};
constexpr auto c_thread_mtx = ThreadMatrixC{};
constexpr index_t M = a_block_mtx.NCol();
constexpr index_t N = b_block_mtx.NCol();
constexpr index_t K = a_block_mtx.NRow(); // A is transposed
constexpr index_t MPerThread = c_thread_mtx.NRow();
constexpr index_t NPerThread = c_thread_mtx.NCol();
// thread A, B for GEMM
// A is transposed, b is not
constexpr auto a_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
constexpr auto b_thread_mtx =
make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
// thread A-sub, B-sub for copy
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
// assertion for inline asm
static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
is_same<FloatC, float>::value,
"Run_asm only deal with float\n");
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
static_assert(
BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread);
Float4* reg_b = (Float4*)(p_b_thread);
Float4* reg_c = (Float4*)(p_c_thread);
void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
constexpr index_t a_lds_row_stride = sizeof(Float) * M;
constexpr index_t b_lds_row_stride = sizeof(Float) * N;
constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster;
constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster;
ds_read_b128(reg_a[0], a_lds_loc, 0);
ds_read_b128(reg_b[0], b_lds_loc, 0);
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
#pragma unroll
for(index_t k = 1; k < K; ++k)
{
ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
lgkmcnt(1);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1);
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
}
lgkmcnt(0);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
}
#endif #endif
template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC> template <class BlockMatrixC, index_t BlockMatrixStrideC, class FloatC>
......
...@@ -273,7 +273,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn ...@@ -273,7 +273,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
__syncthreads(); __syncthreads();
#if 1
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
#elif 0
blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
#elif 0
blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
#endif
__syncthreads(); __syncthreads();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment