Commit 0f6cb787 authored by Chao Liu

double index buffer

parent 7d09790a
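The diff below replaces the single blockwise copy object (one slice-window index) with two copy objects per matrix, a_block_copy_0 / a_block_copy_1 (and likewise for B), whose source windows start KPerBlock apart and each advance by 2 * KPerBlock. Below is a minimal host-side sketch of that two-window ping-pong; SliceCopy, KPerBlock and K here are simplified stand-ins for illustration, not the kernel's real classes.

#include <cstdio>

constexpr int KPerBlock = 8;
constexpr int K         = 64; // assumed to be a multiple of 2 * KPerBlock for this sketch

// stand-in for the source-window state of BlockwiseGenericTensorSliceCopy_v4
struct SliceCopy
{
    int k_offset; // source slice-window origin along the K dimension

    void MoveSrcSliceWindow(int step) { k_offset += step; }
    void Run() const { std::printf("copy K-window starting at %d\n", k_offset); }
};

int main()
{
    // two copy objects (two index buffers), staggered by KPerBlock
    SliceCopy copy_0{0};
    SliceCopy copy_1{KPerBlock};

    // each object advances by 2 * KPerBlock, so together they still cover
    // every KPerBlock slice, but each keeps its own window state
    for(int k = 0; k < K; k += 2 * KPerBlock)
    {
        copy_0.Run();
        copy_0.MoveSrcSliceWindow(2 * KPerBlock);

        copy_1.Run();
        copy_1.MoveSrcSliceWindow(2 * KPerBlock);
    }
    return 0;
}

With two independent window states, the double-buffered loop in the diff can issue the next load and window move before running the GEMM on the current buffer.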
@@ -77,7 +77,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
return 2 * (a_block_space + b_block_space) * sizeof(Float);
}

-__device__ void Run(const Float* __restrict__ p_a_global,
+__device__ void Run_single_slice_window(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
@@ -125,7 +125,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});

// A matrix blockwise copy
-auto a_blockwise_copy =
+using a_block_copy_type =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
@@ -142,8 +142,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
-InMemoryDataOperation::Set>(
-{0, m_block_data_on_global}, {0, 0});
+InMemoryDataOperation::Set>;

// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
@@ -151,7 +150,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});

// B matrix blockwise copy
-auto b_blockwise_copy =
+using b_block_copy_type =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
@@ -168,8 +167,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
-InMemoryDataOperation::Set>(
-{0, n_block_data_on_global}, {0, 0});
+InMemoryDataOperation::Set>;

// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
@@ -225,14 +223,27 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);

+// prepare blockwise copy slicing window
+auto a_block_copy_0 = a_block_copy_type({0, m_block_data_on_global}, {0, 0});
+auto b_block_copy_0 = b_block_copy_type({0, n_block_data_on_global}, {0, 0});
+
+auto a_block_copy_1 = a_block_copy_0;
+auto b_block_copy_1 = b_block_copy_0;
+
+a_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+b_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
+
+constexpr auto a_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+constexpr auto b_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
+
// LDS double buffer: preload data into LDS
{
-a_blockwise_copy.Run(p_a_global, p_a_block_double);
-b_blockwise_copy.Run(p_b_global, p_b_block_double);
+a_block_copy_0.Run(p_a_global, p_a_block_double);
+b_block_copy_0.Run(p_b_global, p_b_block_double);
+
+a_block_copy_0.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+b_block_copy_0.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
}
-
-constexpr auto a_block_slice_copy_steps = Sequence<KPerBlock, 0>{};
-constexpr auto b_block_slice_copy_steps = Sequence<KPerBlock, 0>{};

// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
@@ -253,24 +264,30 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;

-Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+auto& a_block_copy_now = even_loop ? a_block_copy_0 : a_block_copy_1;
+auto& b_block_copy_now = even_loop ? b_block_copy_0 : b_block_copy_1;
+
+auto& a_block_copy_next = even_loop ? a_block_copy_1 : a_block_copy_0;
+auto& b_block_copy_next = even_loop ? b_block_copy_1 : b_block_copy_0;
+
+Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];

__syncthreads();

// LDS double buffer: load next data from device mem
-a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+a_block_copy_next.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+b_block_copy_next.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+a_block_copy_next.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+b_block_copy_next.MoveSrcSliceWindow(b_block_slice_copy_steps, True);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);

// LDS double buffer: store next data to LDS
-a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
-b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
+a_block_copy_next.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
+b_block_copy_next.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
@@ -280,25 +297,22 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
if(has_two_iteration_left) // if has 2 iteration left
{
-Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
-b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
+Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];

__syncthreads();

// LDS double buffer: load last data from device mem
-a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+a_block_copy_1.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+b_block_copy_1.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

-// LDS double buffer: GEMM on 2nd-last data
+// LDS double buffer: GEMM on even (2nd-last) data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);

-// LDS double buffer: store last data to LDS
-a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
-p_a_block_double + a_block_space);
-b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
-p_b_block_double + b_block_space);
+// LDS double buffer: store odd (last) data to LDS
+a_block_copy_1.RunStoreThreadBuffer(p_a_thread_buffer,
+p_a_block_double + a_block_space);
+b_block_copy_1.RunStoreThreadBuffer(p_b_thread_buffer,
+p_b_block_double + b_block_space);

__syncthreads();
@@ -311,7 +325,314 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
{
__syncthreads();
// LDS double buffer: GEMM on even (last) data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
// input: register to global memory
{
constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
constexpr index_t M0 = M / M1;
constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
constexpr index_t N0 = N / N1;
// define input tensor descriptor for threadwise copy
// thread input tensor, src of threadwise copy
constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, MPerThread, GemmNRepeat, NPerThread>{});
constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor(
c_m_n_global_desc,
make_tuple(UnMerge<Sequence<M0, M1>>{}, UnMerge<Sequence<N0, N1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
// calculate origin of thread input tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
const index_t m_thread_data_on_global =
m_block_data_on_global + c_thread_mtx_on_block.row;
const index_t n_thread_data_on_global =
n_block_data_on_global + c_thread_mtx_on_block.col;
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
CThreadCopySrcDstAccessOrder,
CThreadCopySrcDstVectorReadWriteDim,
1,
CThreadCopyDstDataPerWrite,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation>(
{0, 0, 0, 0},
{m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1})
.Run(p_c_thread, p_c_global);
}
}
__device__ void Run_double_slice_window(const Float* __restrict__ p_a_global,
const Float* __restrict__ p_b_global,
Float* __restrict__ p_c_global,
Float* __restrict__ p_shared_block) const
{
constexpr auto True = integral_constant<bool, true>{};
constexpr auto a_k_m_global_desc = AGlobalDesc{};
constexpr auto b_k_n_global_desc = BGlobalDesc{};
constexpr auto c_m_n_global_desc = CGlobalDesc{};
constexpr auto K = a_k_m_global_desc.GetLengths()[0];
constexpr auto M = a_k_m_global_desc.GetLengths()[1];
constexpr auto N = b_k_n_global_desc.GetLengths()[1];
// don't do anything if K == 0
if(K == 0)
{
return;
}
// lds max alignment
constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
BBlockCopyDstDataPerWrite_N,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N);
// divide block work by [M, N]
static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
"wrong! cannot divide work evenly among block");
constexpr index_t MBlockWork = M / MPerBlock;
constexpr index_t NBlockWork = N / NPerBlock;
constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<MBlockWork, NBlockWork>{});
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, MPerBlock>{}, Number<max_lds_align>{});
// A matrix blockwise copy
using a_block_copy_type =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(a_k_m_global_desc),
decltype(a_k_m_block_desc),
decltype(a_k_m_block_desc.GetLengths()),
ABlockCopyThreadSliceLengths_K_M,
ABlockCopyThreadClusterLengths_K_M,
ABlockCopyThreadClusterArrangeOrder,
ABlockCopySrcAccessOrder,
Sequence<0, 1>,
ABlockCopySrcVectorReadDim,
1,
ABlockCopySrcDataPerRead,
ABlockCopyDstDataPerWrite_M,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>;
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
Sequence<KPerBlock, NPerBlock>{}, Number<max_lds_align>{});
// B matrix blockwise copy
using b_block_copy_type =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(b_k_n_global_desc),
decltype(b_k_n_block_desc),
decltype(b_k_n_block_desc.GetLengths()),
BBlockCopyThreadSliceLengths_K_N,
BBlockCopyThreadClusterLengths_K_N,
BBlockCopyThreadClusterArrangeOrder,
BBlockCopySrcAccessOrder,
Sequence<0, 1>,
BBlockCopySrcVectorReadDim,
1,
BBlockCopySrcDataPerRead,
BBlockCopyDstDataPerWrite_N,
AddressSpace::Global,
AddressSpace::Vgpr,
AddressSpace::Lds,
InMemoryDataOperation::Set>;
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[KPerBlock, MPerBlock] is in LDS
// b_mtx[KPerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
// sanity check
static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
"wrong!");
constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * MPerThread>{}, Number<GemmNRepeat * NPerThread>{});
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_k_m_block_mtx_desc),
decltype(b_k_n_block_mtx_desc),
decltype(c_m0m1_n0n1_thread_mtx_desc),
MPerThread,
NPerThread,
MLevel0Cluster,
NLevel0Cluster,
MLevel1Cluster,
NLevel1Cluster,
KPerThread,
ThreadGemmAThreadCopySrcDataPerRead_M,
ThreadGemmBThreadCopySrcDataPerRead_N>{};
// LDS allocation for A and B: be careful of alignment
constexpr index_t a_block_space =
math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
constexpr index_t b_block_space =
math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
Float* p_a_block_double = p_shared_block;
Float* p_b_block_double = p_shared_block + 2 * a_block_space;
// register allocation for output
AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
// zero out threadwise output
threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
// prepare blockwise copy slicing window
auto a_block_copy_0 = a_block_copy_type({0, m_block_data_on_global}, {0, 0});
auto b_block_copy_0 = b_block_copy_type({0, n_block_data_on_global}, {0, 0});
#if 0
auto a_block_copy_1 = a_block_copy_0;
auto b_block_copy_1 = b_block_copy_0;
a_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
b_block_copy_1.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
#else
auto a_block_copy_1 = a_block_copy_type({KPerBlock, m_block_data_on_global}, {0, 0});
auto b_block_copy_1 = b_block_copy_type({KPerBlock, n_block_data_on_global}, {0, 0});
#endif
constexpr auto a_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
constexpr auto b_block_slice_copy_steps = Sequence<2 * KPerBlock, 0>{};
// LDS double buffer: preload data into LDS
{
a_block_copy_0.Run(p_a_global, p_a_block_double);
b_block_copy_0.Run(p_b_global, p_b_block_double);
a_block_copy_0.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_block_copy_0.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
}
// LDS double buffer: main body
for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
k_block_data_begin += 2 * KPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);
Float* p_a_block_now =
even_loop ? p_a_block_double : p_a_block_double + a_block_space;
Float* p_b_block_now =
even_loop ? p_b_block_double : p_b_block_double + b_block_space;
Float* p_a_block_next =
even_loop ? p_a_block_double + a_block_space : p_a_block_double;
Float* p_b_block_next =
even_loop ? p_b_block_double + b_block_space : p_b_block_double;
auto& a_block_copy_now = even_loop ? a_block_copy_0 : a_block_copy_1;
auto& b_block_copy_now = even_loop ? b_block_copy_0 : b_block_copy_1;
auto& a_block_copy_next = even_loop ? a_block_copy_1 : a_block_copy_0;
auto& b_block_copy_next = even_loop ? b_block_copy_1 : b_block_copy_0;
Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
__syncthreads();
// LDS double buffer: load next data from device mem
a_block_copy_next.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_block_copy_next.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
a_block_copy_next.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
b_block_copy_next.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
// LDS double buffer: store next data to LDS
a_block_copy_next.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
b_block_copy_next.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
}
}
// LDS double buffer: tail
{
constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_a_thread_buffer[a_block_copy_type::GetThreadBufferSize()];
Float p_b_thread_buffer[b_block_copy_type::GetThreadBufferSize()];
__syncthreads();
// LDS double buffer: load last data from device mem
a_block_copy_1.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
b_block_copy_1.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on even (2nd-last) data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
// LDS double buffer: store odd (last) data to LDS
a_block_copy_1.RunStoreThreadBuffer(p_a_thread_buffer,
p_a_block_double + a_block_space);
b_block_copy_1.RunStoreThreadBuffer(p_b_thread_buffer,
p_b_block_double + b_block_space);
__syncthreads();
// LDS double buffer: GEMM on last data
blockwise_gemm.Run(
p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
}
else // if has 1 iteration left
{
__syncthreads();
// LDS double buffer: GEMM on even (last) data
blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
}
}
@@ -373,7 +694,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
__shared__ Float p_shared_block[shared_block_size];

-Run(p_a_global, p_b_global, p_c_global, p_shared_block);
+#if 1
+Run_single_slice_window(p_a_global, p_b_global, p_c_global, p_shared_block);
+#else
+Run_double_slice_window(p_a_global, p_b_global, p_c_global, p_shared_block);
+#endif
}
};
...
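Both Run variants keep the same K-loop schedule: the main body consumes two KPerBlock slices per trip while k_block_data_begin + 2 * KPerBlock < K, and the tail then performs either two block GEMMs (when K % (2 * KPerBlock) == 0) or one. A small standalone sketch of that iteration count, with illustrative values chosen here for K and KPerBlock:

#include <cstdio>

// Illustrative values only; in the kernel K and KPerBlock come from the
// tensor descriptors and template parameters at compile time.
constexpr int KPerBlock = 8;
constexpr int K         = 40; // 5 slices of KPerBlock -> odd count, tail runs 1 GEMM

int main()
{
    int gemms = 0;

    // main body: two KPerBlock slices per trip, matching
    // for(k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
    for(int k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
    {
        gemms += 2;
    }

    // tail: two iterations are left iff K is an even multiple of KPerBlock
    const bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
    gemms += has_two_iteration_left ? 2 : 1;

    std::printf("K=%d, KPerBlock=%d -> %d block GEMMs (expected %d)\n",
                K, KPerBlock, gemms, K / KPerBlock);
    return 0;
}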
@@ -189,11 +189,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
+constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
-constexpr index_t GemmKPerThreadLoop = 1;

constexpr index_t ThreadGemmDataPerReadM = 4;
constexpr index_t ThreadGemmDataPerReadN = 4;
@@ -237,11 +238,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
+GemmKPerThreadLoop,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
-GemmKPerThreadLoop,
ThreadGemmDataPerReadM,
ThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
...
@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
{
using namespace ck;

-#if 1
+#if 0
// 1x1
constexpr index_t N = 64;
-constexpr index_t C = 64;
-constexpr index_t HI = 56;
-constexpr index_t WI = 56;
-constexpr index_t K = 256;
+constexpr index_t C = 512;
+constexpr index_t HI = 28;
+constexpr index_t WI = 28;
+constexpr index_t K = 1024;
constexpr index_t Y = 1;
constexpr index_t X = 1;
@@ -44,13 +44,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;

-#elif 0
+#elif 1
// 1x7
constexpr index_t N = 128;
-constexpr index_t C = 1024;
+constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
-constexpr index_t K = 1024;
+constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
@@ -281,7 +281,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;

-#elif 0
+#elif 1
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr index_t N = 128;
@@ -327,7 +327,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 3>;
using RightPads = Sequence<0, 3>;

-#elif 1
+#elif 0
// 7x1 filter, 3x0 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 128;
@@ -434,7 +434,7 @@ int main(int argc, char* argv[])
ConvStrides{},
ConvDilations{},
nrepeat);
-#elif 0
+#elif 1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
wei_kcyx_desc,
...