"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c1184918c5cb79e816d5e5e18c023e8d671fd9e8"
Commit f7498d66 authored by Jing Zhang

fixed conflict

parents 5fbf4f33 d6d9a8e4
@@ -190,8 +190,8 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t WeiBlockCopyDataPerRead = 4;
     constexpr index_t BlockSize = 256;
-#elif 1
-    // 1x1, 14x14, Vega 10
+#elif 0
+    // 1x1, 14x14, Vega 20
     constexpr index_t BPerBlock = 64;
     constexpr index_t KPerBlock = 128;
     constexpr index_t CPerBlock = 8;
@@ -219,6 +219,36 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
     constexpr index_t InBlockCopyDataPerRead = 4;
     constexpr index_t WeiBlockCopyDataPerRead = 4;
+    constexpr index_t BlockSize = 128;
+#elif 1
+    // 1x1, 14x14, Vega 20, hack CPerBlock = 1
+    constexpr index_t BPerBlock = 64;
+    constexpr index_t KPerBlock = 128;
+    constexpr index_t CPerBlock = 1;
+    constexpr index_t BPerThread = 8;
+    constexpr index_t KPerThread = 8;
+    constexpr index_t GemmMPerThreadSubC = 4;
+    constexpr index_t GemmNPerThreadSubC = 4;
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 2;
+    constexpr index_t GemmMLevel1Cluster = 4;
+    constexpr index_t GemmNLevel1Cluster = 4;
+    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmThreadPerColumnPerCluster = 8;
+    constexpr index_t GemmThreadPerRowPerCluster = 8;
+    constexpr index_t InBlockCopyThreadPerDim0 = 4;
+    constexpr index_t InBlockCopyThreadPerDim1 = 16;
+    constexpr index_t WeiBlockCopyThreadPerDim0 = 4;
+    constexpr index_t WeiBlockCopyThreadPerDim1 = 16;
+    constexpr index_t InBlockCopyDataPerRead = 4;
+    constexpr index_t WeiBlockCopyDataPerRead = 4;
     constexpr index_t BlockSize = 128;
 #endif
......
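Aside: the new "hack CPerBlock = 1" branch keeps BlockSize = 128, which is consistent with its GEMM cluster parameters, since the level-0 and level-1 clusters multiply out to 4 * 2 * 4 * 4 = 128 threads. A minimal sketch of that consistency check (the product relation is my reading of how the blockwise GEMM distributes threads over the C tile, not something stated in this commit):

// Sketch only: values copied from the new "#elif 1" branch above.
constexpr unsigned GemmMLevel0Cluster = 4;
constexpr unsigned GemmNLevel0Cluster = 2;
constexpr unsigned GemmMLevel1Cluster = 4;
constexpr unsigned GemmNLevel1Cluster = 4;
constexpr unsigned BlockSize = 128;

// 4 * 2 * 4 * 4 == 128: the cluster decomposition covers exactly one workgroup.
static_assert(GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                      GemmNLevel1Cluster == BlockSize,
              "GEMM thread clusters should cover the workgroup");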
@@ -477,9 +477,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
     }

     template <class FloatA, class FloatB, class FloatC, class Accumulator>
-    __device__ void Run_asm(const FloatA* __restrict__ p_a_block,
-                            const FloatB* __restrict__ p_b_block,
-                            FloatC* __restrict__ p_c_thread,
+    __device__ void Run_asm(const FloatA* const __restrict__ p_a_block,
+                            const FloatB* const __restrict__ p_b_block,
+                            FloatC* const __restrict__ p_c_thread,
                             Accumulator f_accum) const
     {
         constexpr auto True = integral_constant<bool, true>{};
@@ -519,11 +519,18 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
+        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
+                          KPerThreadLoop == 1 && K == 1,
+                      "asm is not for this mtx shape");
+
+        const FloatA* const p_a_block_thread_offset = p_a_block + mMyThreadOffsetA;
 #pragma unroll
         // loop over k
         for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
         {
-            //#pragma unroll
+#if 0
+#pragma unroll
             // copy A-sub to form A
             for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
             {
@@ -532,9 +539,65 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                     p_a_block + a_block_mtx.Get1dIndex(k_begin, m_repeat * MPerLevel1Cluster) +
                         mMyThreadOffsetA,
                     a_thread_mtx,
-                    p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
+                    a_thread_sub_mtx.NCol(p_a_thread + a_thread_mtx.Get1dIndex(0, m_repeat * MPerThreadSubC),
                     a_thread_sub_mtx.GetLengths());
             }
+#elif 1
+            // this produces the right result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // float4 when FloatA is float
+
+            asm volatile("\n \
+    ds_read_b128 %0, %1 \n \
+    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
+                         : "v"(__to_local(
+                               (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))));
+
+            asm volatile("\n \
+    ds_read_b128 %0, %1 \n \
+    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local(
+                               (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
+                                       mMyThreadOffsetA))));
+#elif 0
+            // this produces the wrong result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // float4 when FloatA is float
+
+            asm volatile("\n \
+    ds_read_b128 %0, %2 \n \
+    ds_read_b128 %1, %3 \n \
+    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0)))),
+                           "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local(
+                               (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA))),
+                           "v"(__to_local(
+                               (void*)(p_a_block + a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) +
+                                       mMyThreadOffsetA))));
+#elif 1
+            // this produces the wrong result
+            using vectorA_t = typename vector_type<FloatA, 4>::MemoryType; // float4 when FloatA is float
+
+            asm volatile("\n \
+    ds_read_b128 %0, %1 \n \
+    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(p_a_thread + a_thread_mtx.Get1dIndex(0, 0))))
+                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
+
+            asm volatile("\n \
+    ds_read_b128 %0, %1 offset:16 \n \
+    s_waitcnt lgkmcnt(0)"
+                         : "=v"(*(reinterpret_cast<vectorA_t*>(
+                               p_a_thread + a_thread_mtx.Get1dIndex(0, MPerThreadSubC))))
+                         : "v"(__to_local((void*)(p_a_block_thread_offset))));
+#endif
             //#pragma unroll
             // copy B-sub to form B
......
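For readers unfamiliar with the GCN asm in Run_asm: ds_read_b128 loads 128 bits (one float4 per lane) from LDS, and the variant the commit marks as producing the right result issues one asm statement per load, each ending in its own s_waitcnt lgkmcnt(0). Below is a standalone sketch of that working pattern (HIP; the kernel and its names are illustrative, not from the commit, and __to_local is the HCC-era intrinsic this commit declares in a later hunk):

#include <hip/hip_runtime.h>

// Declaration added by this commit (HCC-era toolchain; newer ROCm spells
// the LDS address-space cast differently).
extern "C" __attribute__((address_space(3))) void* __to_local(void* p);

// Each lane stages 4 floats into LDS, then reads them back with a single
// ds_read_b128 per asm statement, as in the "right result" variant.
// Assumes blockDim.x == 64 and a 16-byte-aligned output buffer.
__global__ void ds_read_b128_demo(const float* __restrict__ p_in,
                                  float* __restrict__ p_out)
{
    __shared__ alignas(16) float lds[256]; // ds_read_b128 wants 16-byte alignment

    const unsigned tid = threadIdx.x;
    for(unsigned i = 0; i < 4; ++i)
        lds[4 * tid + i] = p_in[4 * tid + i];
    __syncthreads();

    float4 v;
    asm volatile("\n \
    ds_read_b128 %0, %1 \n \
    s_waitcnt lgkmcnt(0)"
                 : "=v"(v)
                 : "v"(__to_local((void*)(lds + 4 * tid))));

    reinterpret_cast<float4*>(p_out)[tid] = v;
}

One observation on the third variant above: it reads the second A sub-tile at a hard-coded offset:16 bytes from the same base register, whereas the working variant offsets the address by a_block_mtx.Get1dIndex(k_begin, MPerLevel1Cluster) elements; unless that stride happens to equal 16 bytes, the fixed offset alone would explain the wrong result.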
@@ -5,6 +5,8 @@
 #include "Array.hip.hpp"
 #include "functional.hip.hpp"

+extern "C" __attribute__((address_space(3))) void* __to_local(void* p)[[hc]];
+
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }
 __device__ index_t get_block_1d_id() { return blockIdx.x; }
......
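The one-line addition above is what the inline asm elsewhere in this commit relies on: ds_read_* instructions want an LDS pointer (address space 3), and __to_local casts a generic pointer down to that space ([[hc]] is the HCC heterogeneous-compute attribute of the toolchain this code targets). A small usage gloss (mine, not from the commit; conceptually the result carries a 32-bit byte offset into the workgroup's LDS):

// Inside device code: turn a generic pointer into the operand that a
// "v" constraint on ds_read_b32/b64/b128 expects.
__shared__ float tile[64];
float* p_generic = &tile[threadIdx.x]; // flat/generic pointer
__attribute__((address_space(3))) void* p_lds = __to_local((void*)p_generic);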
@@ -238,7 +238,7 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn(const Float* const __restrict__
     auto f_accum = [](auto& acc, const auto&& v) { acc += v; };

 #if 1
     blockwise_gemm.Run
-#elif 0
+#elif 1
     blockwise_gemm.Run_asm
 #elif 1
     blockwise_gemm.Run_RegisterDoubleBuffer
......
@@ -289,10 +289,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #else
                 blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
                     (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                      p_in_block_now + y * Wi + x,
                      p_out_thread,
                      f_accum);
             }
         }
@@ -319,10 +319,10 @@ gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer(
 #else
                 blockwise_gemm.Run_RegisterDoubleBuffer
 #endif
                     (p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
                      p_in_block_now + y * Wi + x,
                      p_out_thread,
                      f_accum);
             }
         }
     }
......
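For orientation, both call sites above sit inside loops over the filter taps: for each (y, x) the weight block is re-indexed through wei_cyxk_block_desc, while the input block pointer is simply shifted by y * Wi + x because the input block stores rows Wi elements apart. A minimal sketch of that pointer arithmetic (Y, X, and Wi are stand-in values for a generic 3x3 case, not values from this diff):

#include <hip/hip_runtime.h>

// Sketch: how each filter tap (y, x) offsets into the input block.
constexpr int Wi = 16;      // row stride of the input block, in elements
constexpr int Y = 3, X = 3; // filter height and width

__device__ void visit_taps(const float* p_in_block_now)
{
    for(int y = 0; y < Y; ++y)
        for(int x = 0; x < X; ++x)
        {
            // Same arithmetic as the call sites above: tap (y, x) begins
            // y rows and x columns into the input block.
            const float* p_in_tap = p_in_block_now + y * Wi + x;
            (void)p_in_tap; // the real code hands this to blockwise_gemm
        }
}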
@@ -23,7 +23,7 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
             p_dst[dst_index] = p_src[src_index];
         }
     }
-#elif 1
+#elif 0
     static_assert(NCol == 4, "only for NCol == 4");

     using vector_t = typename vector_type<Float, 4>::MemoryType;
@@ -33,15 +33,21 @@ __device__ void threadwise_matrix_copy(SrcMatrix,
         const index_t src_index = src_mtx.Get1dIndex(i, 0);
         const index_t dst_index = dst_mtx.Get1dIndex(i, 0);

-#if 1
-        *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
-            *(reinterpret_cast<const vector_t*>(p_src + src_index));
+#if 0
+        *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
+            *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
+#elif 0
+        asm volatile("\n \
+    ds_read2_b64 %0, %1 offset1:1 \n \
+    s_waitcnt lgkmcnt(0)"
+                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
+                     : "v"(__to_local((void*)(&p_src[src_index]))));
 #elif 1
-        asm volatile("\n \
-    ds_read_b128 %0, %1, offset:0 \n \
-    "
-                     : "=v"(*(reinterpret_cast<vector_t*>(p_dst+dst_index)))
-                     : "v"((uint32_t)(p_src + src_index)));
+        asm volatile("\n \
+    ds_read_b128 %0, %1 \n \
+    s_waitcnt lgkmcnt(0)"
+                     : "=v"(*(reinterpret_cast<vector_t*>(&p_dst[dst_index])))
+                     : "v"(__to_local((void*)(&p_src[src_index]))));
 #endif
     }
 #endif
......
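The last hunk moves the fast path of threadwise_matrix_copy from a plain vector_t assignment to inline asm and adds a ds_read2_b64 alternative. For context: ds_read2_b64 %0, %1 offset1:1 loads two 64-bit halves, the second from base + 8 bytes (offset1 counts in 8-byte units), so it covers the same 16 bytes as one ds_read_b128 while needing only 8-byte alignment. A side-by-side sketch of the two encodings (helper names are mine; __to_local as declared earlier in this commit):

#include <hip/hip_runtime.h>

extern "C" __attribute__((address_space(3))) void* __to_local(void* p);

// Copy 4 consecutive floats (16 bytes) out of LDS with one 128-bit read;
// ds_read_b128 expects a 16-byte-aligned LDS address.
__device__ void copy4_b128(const float* p_src_lds, float* p_dst)
{
    asm volatile("\n \
    ds_read_b128 %0, %1 \n \
    s_waitcnt lgkmcnt(0)"
                 : "=v"(*reinterpret_cast<float4*>(p_dst))
                 : "v"(__to_local((void*)p_src_lds)));
}

// Same 16 bytes as two 64-bit reads; offset1:1 addresses base + 8 bytes.
__device__ void copy4_read2_b64(const float* p_src_lds, float* p_dst)
{
    asm volatile("\n \
    ds_read2_b64 %0, %1 offset1:1 \n \
    s_waitcnt lgkmcnt(0)"
                 : "=v"(*reinterpret_cast<float4*>(p_dst))
                 : "v"(__to_local((void*)p_src_lds)));
}

Note also what the old operand did: casting the generic pointer to uint32_t truncates a 64-bit flat address, which is generally not a valid LDS offset; routing it through __to_local is what makes the operand meaningful to the DS instruction.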