"...composable_kernel-1.git" did not exist on "6fe3627a9eb35f1237266f1b6cc8fd3456aed67d"
Commit b5b4fd28 authored by Chao Liu's avatar Chao Liu
Browse files

refactor

parent c64f63d5
...@@ -354,10 +354,10 @@ int main() ...@@ -354,10 +354,10 @@ int main()
{ {
#if 0 #if 0
constexpr unsigned N = 1; constexpr unsigned N = 1;
constexpr unsigned C = 2; constexpr unsigned C = 1;
constexpr unsigned HI = 34; constexpr unsigned HI = 34;
constexpr unsigned WI = 34; constexpr unsigned WI = 34;
constexpr unsigned K = 2; constexpr unsigned K = 4;
constexpr unsigned S = 3; constexpr unsigned S = 3;
constexpr unsigned R = 3; constexpr unsigned R = 3;
#elif 1 #elif 1
...@@ -418,7 +418,7 @@ int main() ...@@ -418,7 +418,7 @@ int main()
device_direct_convolution_2 device_direct_convolution_2
#elif 0 #elif 0
device_implicit_gemm_convolution_1_nchw_kcsr device_implicit_gemm_convolution_1_nchw_kcsr
#elif 1 #elif 0
device_implicit_gemm_convolution_1_nchw_srck_nkhw device_implicit_gemm_convolution_1_nchw_srck_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_2_cnhw_srck_knhw device_implicit_gemm_convolution_2_cnhw_srck_knhw
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
const Tensor<T>& in_nchw, const Tensor<T>& in_nchw,
WeiDesc, WeiDesc,
const Tensor<T>& wei_kcsr, const Tensor<T>& wei_kcsr,
OutDesc, OutDesc,
Tensor<T>& out_nkhw, Tensor<T>& out_nkhw,
unsigned nrepeat) unsigned nrepeat)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{}; constexpr auto I1 = Number<1>{};
...@@ -104,7 +104,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -104,7 +104,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
constexpr unsigned WoPerThread = 1; constexpr unsigned WoPerThread = 1;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#elif 1 #elif 0
constexpr unsigned NPerBlock = 2; constexpr unsigned NPerBlock = 2;
constexpr unsigned KPerBlock = 32; constexpr unsigned KPerBlock = 32;
constexpr unsigned CPerBlock = 4; constexpr unsigned CPerBlock = 4;
...@@ -137,20 +137,20 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -137,20 +137,20 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
cudaEventRecord(start, 0); cudaEventRecord(start, 0);
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize, gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
BlockSize, BlockSize,
T, T,
decltype(in_nchw_desc), decltype(in_nchw_desc),
decltype(wei_srck_desc), decltype(wei_srck_desc),
decltype(out_nkhw_desc), decltype(out_nkhw_desc),
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
CPerBlock, CPerBlock,
HoPerBlock, HoPerBlock,
WoPerBlock, WoPerBlock,
KPerThread, KPerThread,
CPerThread, CPerThread,
HoPerThread, HoPerThread,
WoPerThread> WoPerThread>
<<<grid_dim, block_dim>>>(in_nchw_desc, <<<grid_dim, block_dim>>>(in_nchw_desc,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
wei_srck_desc, wei_srck_desc,
...@@ -165,10 +165,9 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc, ...@@ -165,10 +165,9 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
cudaEventElapsedTime(&elapsedTime, start, stop); cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime); printf("Elapsed time : %f ms\n", elapsedTime);
usleep(10); usleep(10000);
} }
checkCudaErrors(cudaGetLastError()); checkCudaErrors(cudaGetLastError());
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
} }
#pragma once #pragma once
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh" #include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
#include <unistd.h>
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
...@@ -67,35 +68,29 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -67,35 +68,29 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
#if 0 #if 0
constexpr unsigned BPerBlock = 128; constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 1; constexpr unsigned KPerBlock = 4;
constexpr unsigned CPerBlock = 1; constexpr unsigned CPerBlock = 1;
constexpr unsigned BPerThread = 4; constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 1; constexpr unsigned KPerThread = 1;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned BlockSize = 32; constexpr unsigned ThreadPerClusterRow = 4;
#elif 0 constexpr unsigned ThreadPerClusterColumn = 16;
constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 2;
constexpr unsigned CPerBlock = 2;
constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 2;
constexpr unsigned CPerThread = 1;
constexpr unsigned BlockSize = 32; constexpr unsigned BlockSize = 128;
#elif 1 #elif 1
constexpr unsigned BPerBlock = 128; constexpr unsigned BPerBlock = 128;
constexpr unsigned KPerBlock = 64; constexpr unsigned KPerBlock = 64;
constexpr unsigned CPerBlock = 2; constexpr unsigned CPerBlock = 2;
constexpr unsigned BPerBatch = 32;
constexpr unsigned BPerThread = 4; constexpr unsigned BPerThread = 4;
constexpr unsigned KPerThread = 16; constexpr unsigned KPerThread = 16;
constexpr unsigned CPerThread = 1; constexpr unsigned CPerThread = 1;
constexpr unsigned ThreadPerClusterRow = 4;
constexpr unsigned ThreadPerClusterColumn = 16;
constexpr unsigned BlockSize = 128; constexpr unsigned BlockSize = 128;
#endif #endif
...@@ -137,7 +132,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -137,7 +132,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
BPerThread, BPerThread,
KPerThread, KPerThread,
CPerThread, CPerThread,
BPerBatch> ThreadPerClusterRow,
ThreadPerClusterColumn>
<<<grid_dim, block_dim>>>(in_cnhw_desc, <<<grid_dim, block_dim>>>(in_cnhw_desc,
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()), static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
wei_srck_desc, wei_srck_desc,
...@@ -151,6 +147,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc, ...@@ -151,6 +147,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
cudaEventElapsedTime(&elapsedTime, start, stop); cudaEventElapsedTime(&elapsedTime, start, stop);
printf("Elapsed time : %f ms\n", elapsedTime); printf("Elapsed time : %f ms\n", elapsedTime);
usleep(10000);
} }
checkCudaErrors(cudaGetLastError()); checkCudaErrors(cudaGetLastError());
......
...@@ -156,11 +156,11 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c ...@@ -156,11 +156,11 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0"); static_assert(MPerBlock % MPerThread == 0, "MPerBlock % MPerThread != 0");
static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0"); static_assert(NPerBlock % NPerThread == 0, "NPerBlock % NPerThread != 0");
constexpr unsigned BThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread; constexpr unsigned BatchThreadWork = (BatchSize + BatchPerThread - 1) / BatchPerThread;
constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread; constexpr unsigned MThreadWork = (MPerBlock + MPerThread - 1) / MPerThread;
constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread;
static_assert(BlockSize == BThreadWork * MThreadWork * NThreadWork, static_assert(BlockSize == BatchThreadWork * MThreadWork * NThreadWork,
"wrong! wrong BlockSize"); "wrong! wrong BlockSize");
if(DistributeThreadAlongColumnFirst) if(DistributeThreadAlongColumnFirst)
...@@ -289,3 +289,205 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c ...@@ -289,3 +289,205 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
} }
} }
}; };
// Blockwise GEMM: the BlockSize threads of a thread block cooperatively
// compute a per-thread register tile of C from block-resident A and B tiles.
// BlockMatrixA/BlockMatrixB/ThreadMatrixC are matrix-descriptor types fixing
// the tile shapes; TransA/TransB/TransC select the storage orientation;
// KPerThreadLoop is the k-chunk copied and multiplied per inner iteration;
// MThreadPerCluster x NThreadPerCluster is the thread-cluster layout used to
// distribute the block's C tile across threads.
// NOTE(review): only the (TransA && !TransB && !TransC) case with
// DistributeThreadAlongColumnFirst == true is implemented — all other
// combinations hit assert(false) below.
template <unsigned BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class ThreadMatrixC,
          bool TransA,
          bool TransB,
          bool TransC,
          unsigned KPerThreadLoop,
          unsigned MThreadPerCluster,
          unsigned NThreadPerCluster,
          bool DistributeThreadAlongColumnFirst>
struct blockwise_gemm_block_a_block_b_thread_c
{
    // 1-d element offsets of this thread's sub-tiles inside the A and B block
    // tiles; computed once in the constructor from the thread's flat id.
    unsigned mMyThreadOffsetA = 0;
    unsigned mMyThreadOffsetB = 0;

    // (row, col) origin of one thread's sub-tile within the block's C tile.
    struct MatrixIndex
    {
        unsigned row_begin;
        unsigned col_begin;
    };

    // Maps this thread to its C sub-tile and derives the matching start
    // offsets into the A and B block tiles. The transpose flags decide
    // whether the row/col coordinate indexes a row or a column of the tile.
    __device__ blockwise_gemm_block_a_block_b_thread_c()
    {
        const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
        const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

        const auto c_thread_mtx_index = CalculateThreadMatrixCIndex(get_thread_local_1d_id());

        mMyThreadOffsetA = (!TransA) ? a_block_mtx.Get1dIndex(c_thread_mtx_index.row_begin, 0)
                                     : a_block_mtx.Get1dIndex(0, c_thread_mtx_index.row_begin);

        mMyThreadOffsetB = (!TransB) ? b_block_mtx.Get1dIndex(0, c_thread_mtx_index.col_begin)
                                     : b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0);

#if 0
        // NOTE(review): dead debug code — MatrixIndex has no batch_begin
        // member, so this block would not compile if enabled.
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
            print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
            print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");

            printf("%u %u, %u %u %u, %u %u\n",
                   get_block_1d_id(),
                   get_thread_local_1d_id(),
                   c_thread_mtx_index.batch_begin,
                   c_thread_mtx_index.row_begin,
                   c_thread_mtx_index.col_begin,
                   mMyThreadOffsetA,
                   mMyThreadOffsetB);
        }
#endif
    }

    // Maps a flat thread id to the (row, col) origin of that thread's C
    // sub-tile. Threads are grouped into clusters of
    // MThreadPerCluster x NThreadPerCluster; clusters then tile the
    // MPerBlock x NPerBlock C tile, column-first.
    // NOTE(review): the two "not implemented" paths assert(false) and then
    // fall off the end without returning a MatrixIndex — undefined behavior
    // if ever reached (and with device asserts disabled the assert is a
    // no-op). Consider returning a sentinel or making these static_asserts.
    __device__ MatrixIndex CalculateThreadMatrixCIndex(unsigned thread_id) const
    {
        if(TransA && (!TransB) && (!TransC))
        {
            const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
            const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile

            // A is stored transposed, so both A and B share the k dimension
            // along their rows.
            static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(),
                          "wrong! k dimension not consistent!");

            constexpr unsigned MPerBlock = a_block_mtx.NCol();
            constexpr unsigned NPerBlock = b_block_mtx.NCol();

            const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

            // divide thread work
            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            static_assert(MPerBlock % (MPerThread * MThreadPerCluster) == 0,
                          "MPerBlock % (MPerThread * MThreadPerCluster) != 0");

            static_assert(NPerBlock % (NPerThread * NThreadPerCluster) == 0,
                          "NPerBlock % (NPerThread * NThreadPerCluster) != 0");

            // number of cluster tiles needed to cover the C block in each dim
            constexpr unsigned MClusterWork =
                (MPerBlock + MPerThread * MThreadPerCluster - 1) / (MPerThread * MThreadPerCluster);

            constexpr unsigned NClusterWork =
                (NPerBlock + NPerThread * NThreadPerCluster - 1) / (NPerThread * NThreadPerCluster);

            static_assert(BlockSize == (MClusterWork * MThreadPerCluster) *
                                           (NClusterWork * NThreadPerCluster),
                          "wrong! wrong BlockSize");

            if(DistributeThreadAlongColumnFirst)
            {
                // Decompose thread_id into (which cluster, which thread
                // within the cluster), then each of those into (m, n)
                // coordinates, n varying fastest.
                const unsigned cluster_work_block_id =
                    thread_id / (MThreadPerCluster * NThreadPerCluster);

                const unsigned thread_work_cluster_id =
                    thread_id - cluster_work_block_id * (MThreadPerCluster * NThreadPerCluster);

                const unsigned m_cluster_work_block_id = cluster_work_block_id / NThreadPerCluster;
                const unsigned n_cluster_work_block_id =
                    cluster_work_block_id - m_cluster_work_block_id * NThreadPerCluster;

                const unsigned m_thread_work_cluster_id =
                    thread_work_cluster_id / NThreadPerCluster;
                const unsigned n_thread_work_cluster_id =
                    thread_work_cluster_id - m_thread_work_cluster_id * NThreadPerCluster;

#if 0
                // debug-only dump of the computed work decomposition
                if(get_block_1d_id() == 0)
                {
                    printf("%u %u, \t"
                           //"MClusterWork %u MThreadPerCluster %u NClusterWork %u NThreadPerCluster %u \t"
                           "m_cluster_work_block_id %u n_cluster_work_block_id %u \t"
                           "m_thread_work_cluster_id %u n_thread_work_cluster_id %u \t"
                           "\n",
                           get_block_1d_id(), get_thread_local_1d_id(),
                           //MClusterWork, MThreadPerCluster, NClusterWork, NThreadPerCluster,
                           m_cluster_work_block_id, n_cluster_work_block_id,
                           m_thread_work_cluster_id, n_thread_work_cluster_id);
                }
#endif

                // cluster origin + offset of this thread's tile inside it
                return MatrixIndex{m_cluster_work_block_id * (MThreadPerCluster * MPerThread) +
                                       m_thread_work_cluster_id * MPerThread,
                                   n_cluster_work_block_id * (NThreadPerCluster * NPerThread) +
                                       n_thread_work_cluster_id * NPerThread};
            }
            else
            {
                // not implemented
                assert(false);
            }
        }
        else
        {
            // not implemented
            assert(false);
        }
    }

    // Accumulates this thread's C sub-tile: for each KPerThreadLoop-sized
    // k-chunk of the block tile, copies the thread's A and B slices from the
    // block tiles (presumably LDS-resident — TODO confirm with callers) into
    // per-thread arrays, then folds their product into p_c_thread via f_accum
    // (e.g. c += ab). Only the (TransA && !TransB && !TransC) case is handled;
    // other flag combinations silently do nothing.
    template <class FloatA, class FloatB, class FloatC, class Accumulator>
    __device__ void run(FloatA* const p_a_block,
                        FloatB* const p_b_block,
                        FloatC* p_c_thread,
                        Accumulator f_accum) const
    {
        if(TransA && (!TransB) && (!TransC))
        {
            constexpr auto True  = Constant<bool, true>{};
            constexpr auto False = Constant<bool, false>{};

            const auto a_block_mtx  = BlockMatrixA{}; // constexpr doesn't compile
            const auto b_block_mtx  = BlockMatrixB{}; // constexpr doesn't compile
            const auto c_thread_mtx = ThreadMatrixC{}; // constexpr doesn't compile

            constexpr unsigned KPerBlock = a_block_mtx.NRow(); // A is transposed

            constexpr unsigned MPerThread = c_thread_mtx.NRow();
            constexpr unsigned NPerThread = c_thread_mtx.NCol();

            // a is transposed, b is not
            const auto a_thread_mtx = make_ConstantMatrixDescriptor(
                Number<KPerThreadLoop>{}, Number<MPerThread>{}); // constexpr doesn't compile

            const auto b_thread_mtx = make_ConstantMatrixDescriptor(
                Number<KPerThreadLoop>{}, Number<NPerThread>{}); // constexpr doesn't compile

            // per-thread staging buffers for the current k-chunk
            FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
            FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

            // loop over k
            for(unsigned k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop)
            {
                // advance by whole rows: k indexes rows of both block tiles
                threadwise_matrix_copy(a_block_mtx,
                                       p_a_block + mMyThreadOffsetA +
                                           k_begin * a_block_mtx.RowStride(),
                                       a_thread_mtx,
                                       p_a_thread,
                                       a_thread_mtx.GetLengths());

                threadwise_matrix_copy(b_block_mtx,
                                       p_b_block + mMyThreadOffsetB +
                                           k_begin * b_block_mtx.RowStride(),
                                       b_thread_mtx,
                                       p_b_thread,
                                       b_thread_mtx.GetLengths());

                // C += trans(a_thread) * b_thread, accumulated via f_accum
                threadwise_gemm(a_thread_mtx,
                                True,
                                p_a_thread,
                                b_thread_mtx,
                                False,
                                p_b_thread,
                                c_thread_mtx,
                                False,
                                p_c_thread,
                                f_accum);
            }
        }
    }
};
...@@ -23,11 +23,11 @@ template <unsigned GridSize, ...@@ -23,11 +23,11 @@ template <unsigned GridSize,
unsigned WoPerThread> unsigned WoPerThread>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc, gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
Float* const __restrict__ p_in_global, Float* const __restrict__ p_in_global,
WeiGlobalDesc, WeiGlobalDesc,
Float* const __restrict__ p_wei_global, Float* const __restrict__ p_wei_global,
OutGlobalDesc, OutGlobalDesc,
Float* __restrict__ p_out_global) Float* __restrict__ p_out_global)
{ {
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
......
...@@ -20,7 +20,8 @@ template <unsigned GridSize, ...@@ -20,7 +20,8 @@ template <unsigned GridSize,
unsigned BPerThread, unsigned BPerThread,
unsigned KPerThread, unsigned KPerThread,
unsigned CPerThread, unsigned CPerThread,
unsigned BPerBatch> unsigned ThreadPerClusterRow,
unsigned ThreadPerClusterColumn>
__global__ void __global__ void
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
Float* const __restrict__ p_in_global, Float* const __restrict__ p_in_global,
...@@ -112,31 +113,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -112,31 +113,26 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<CPerBlock>{},
Number<BPerBatch>{}, Number<BPerBlock>{},
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor( const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
const auto blockwise_batched_gemm = const auto blockwise_gemm =
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize, blockwise_gemm_block_a_block_b_thread_c<BlockSize,
decltype(a_cxk_block_mtx_desc), decltype(a_cxk_block_mtx_desc),
decltype(b_cxb_block_mtx_desc), decltype(b_cxb_block_mtx_desc),
decltype(c_kxb_thread_mtx_desc), decltype(c_kxb_thread_mtx_desc),
true, true,
false, false,
false, false,
0, CPerThread,
BPerBatch, ThreadPerClusterRow,
0, ThreadPerClusterColumn,
BPerBlock/BPerBatch, true>{};
1,
CPerThread,
true>{};
// LDS // LDS
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace(); constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
...@@ -175,6 +171,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -175,6 +171,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
__syncthreads(); __syncthreads();
#if 1
// a series of GEMM // a series of GEMM
for(unsigned s = 0; s < S; ++s) for(unsigned s = 0; s < S; ++s)
{ {
...@@ -182,31 +179,31 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -182,31 +179,31 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
{ {
auto f_accum = [](auto& c, const auto&& ab) { c += ab; }; auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
p_in_block + s * Wi + r, p_in_block + s * Wi + r,
p_out_thread, p_out_thread,
f_accum); f_accum);
} }
} }
#endif
} }
// output: register to global mem, // output: register to global mem,
const auto matrix_c_index = const auto matrix_c_index =
blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
const unsigned k_thread_data_begin = matrix_c_index.row_begin; const unsigned k_thread_data_begin = matrix_c_index.row_begin;
const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin; const unsigned b_thread_data_begin = matrix_c_index.col_begin;
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin; const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin; const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
#if 0 #if 0
//if(get_block_1d_id() == 10) if(get_block_1d_id() == 0)
{ {
printf("%u %u, batch_begin %u row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n", printf("%u %u, row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
get_block_1d_id(), get_block_1d_id(),
get_thread_local_1d_id(), get_thread_local_1d_id(),
matrix_c_index.batch_begin,
matrix_c_index.row_begin, matrix_c_index.row_begin,
matrix_c_index.col_begin, matrix_c_index.col_begin,
k_data_begin, k_data_begin,
...@@ -228,7 +225,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc, ...@@ -228,7 +225,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
unsigned w_data = itmp - h_data * Wi; unsigned w_data = itmp - h_data * Wi;
#if 0 #if 0
if(get_block_1d_id() == 10) if(get_block_1d_id() == 0)
{ {
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n", printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
get_block_1d_id(), get_block_1d_id(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment