"...resnet50_tensorflow.git" did not exist on "ffe51b3033ccfc31507b24e9abda106995a18ec4"
Commit 5fbf4f33 authored by Jing Zhang's avatar Jing Zhang
Browse files

inline

parent 2058bec8
#pragma once #pragma once
#include "threadwise_gemm.hip.hpp" #include "threadwise_gemm.hip.hpp"
// Declares __to_local with C linkage: it takes a generic pointer and returns a
// void* qualified with address_space(3). No definition is visible in this file.
// Its result is used as the address operand of the ds_read2_b64 inline assembly
// further down in this file.
// NOTE(review): address_space(3) is presumably the AMDGPU LDS (local/shared)
// address space and [[hc]] the HCC device-code attribute -- confirm against the
// toolchain documentation.
extern "C" __attribute__((address_space(3))) void* __to_local(void* p) [[hc]];
template <index_t BlockSize, template <index_t BlockSize,
class BlockMatrixA, class BlockMatrixA,
class BlockMatrixB, class BlockMatrixB,
...@@ -387,9 +389,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -387,9 +389,11 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
{ {
threadwise_matrix_copy( threadwise_matrix_copy(
a_block_mtx, a_block_mtx,
//MPerLevel1Cluster = 4
p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) + p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
mMyThreadOffsetA, mMyThreadOffsetA,
a_thread_mtx, a_thread_mtx,
//MPerThreadSubC = 4
p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC), p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
a_thread_sub_mtx.GetLengths()); a_thread_sub_mtx.GetLengths());
} }
...@@ -398,11 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -398,11 +402,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; auto src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA;
auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0); auto dst_index = a_thread_sub_mtx.Get1dIndex(0, 0);
const float4* loc = (const float4 *)(p_a_block + src_index); //const float4* loc = (const float4 *)(p_a_block + src_index);
float4* reg = (float4 *)(p_a_thread + dst_index); float4* reg = (float4 *)(p_a_thread + dst_index);
reg[0] = loc[0]; //reg[0] = loc[0];
reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4]; //reg[MPerThreadSubC/4] = loc[MPerLevel1Cluster/4];
asm volatile("\n \
ds_read2_b64 %0, %2 offset1:1 \n \
ds_read2_b64 %1, %2 offset0:16 offset1:17 \n \
s_waitcnt lgkmcnt(0)"
: "=v"(reg[0]), "=v"(reg[1])
: "v"(__to_local((void *)&p_a_block[src_index]))
);
} }
#endif #endif
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment