"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4e125f72ab7670b9041618f949871f016c62a904"
Commit 6a3f3f95 authored by Jing Zhang's avatar Jing Zhang
Browse files

add

parent b188c0d2
...@@ -224,7 +224,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -224,7 +224,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
// 1x1, 14x14, Vega 20, hack CPerBlock = 1 // 1x1, 14x14, Vega 20, hack CPerBlock = 1
constexpr index_t BPerBlock = 64; constexpr index_t BPerBlock = 64;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 1; constexpr index_t CPerBlock = 8;
constexpr index_t BPerThread = 8; constexpr index_t BPerThread = 8;
constexpr index_t KPerThread = 8; constexpr index_t KPerThread = 8;
...@@ -232,7 +232,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -232,7 +232,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmMPerThreadSubC = 4; constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t GemmKPerThreadLoop = 1;
...@@ -249,7 +249,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, ...@@ -249,7 +249,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc,
constexpr index_t InBlockCopyDataPerRead = 4; constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4; constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 256;
#endif #endif
constexpr index_t GridSize = constexpr index_t GridSize =
......
...@@ -580,7 +580,7 @@ int main(int argc, char* argv[]) ...@@ -580,7 +580,7 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0; constexpr index_t HPad = 0;
constexpr index_t WPad = 0; constexpr index_t WPad = 0;
#elif 0 #elif 1
// 1x1 filter, 14x14 image, C = 2048 // 1x1 filter, 14x14 image, C = 2048
constexpr index_t N = 128; constexpr index_t N = 128;
constexpr index_t C = 2048; constexpr index_t C = 2048;
......
...@@ -404,10 +404,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -404,10 +404,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
#else #else
int k = k_begin; int k = k_begin;
ds_read_b128(reg_a[0], a_loc, k * 512); int lds_a_block_off = sizeof(Float) * M;
ds_read_b128(reg_b[0], b_loc, k * 256); int lds_b_block_off = sizeof(Float) * N;
ds_read_b128(reg_b[1], b_loc, 128 + k * 256); int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float);
ds_read_b128(reg_a[1], a_loc, 256 + k * 512); int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float);
ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
lgkmcnt(2); lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1); lgkmcnt(1);
...@@ -416,12 +420,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -416,12 +420,12 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
for(int i = 0; i < k_chunk - 1; i++) for(int i = 0; i < k_chunk - 1; i++)
{ {
k = k + 1; k = k + 1;
ds_read_b128(reg_a[0], a_loc, k * 512); ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off);
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
ds_read_b128(reg_b[0], b_loc, k * 256); ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off);
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
ds_read_b128(reg_b[1], b_loc, 128 + k * 256); ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off);
ds_read_b128(reg_a[1], a_loc, 256 + k * 512); ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off);
lgkmcnt(2); lgkmcnt(2);
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
lgkmcnt(1); lgkmcnt(1);
......
...@@ -297,8 +297,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn ...@@ -297,8 +297,7 @@ class gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), __syncthreads())
__syncthreads())
{ {
// load data // load data
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment