Commit 00899f19 authored by Chao Liu

implicit gemm v1r2: only load 1d filter

parent 96ee9571
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
#include <unistd.h> #include <unistd.h>
#include "device.hpp" #include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp" #include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp" #include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
template <class T, class InDesc, class WeiDesc, class OutDesc> template <class T, class InDesc, class WeiDesc, class OutDesc>
...@@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -78,7 +78,7 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
out_khwn_device_buf.ToDevice(out_khwn.mData.data()); out_khwn_device_buf.ToDevice(out_khwn.mData.data());
#if 0 #if 0
// for 3x3, 34x34, Pascal // for 3x3, 34x34, v1r1, Pascal
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64; constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4; constexpr index_t CPerBlock = 4;
...@@ -112,6 +112,40 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -112,6 +112,40 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
constexpr index_t BlockSize = 128; constexpr index_t BlockSize = 128;
#elif 1 #elif 1
// for 3x3, 34x34, v1r2, Pascal
constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 4;
constexpr index_t WoPerBlock = 8;
constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
constexpr index_t InBlockCopyDataPerRead = 4;
constexpr index_t WeiBlockCopyDataPerRead = 4;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 2;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
constexpr index_t OutThreadCopyDataPerWrite = 2;
constexpr index_t BlockSize = 128;
#elif 0
// for 3x3, 34x34, Vega 20 // for 3x3, 34x34, Vega 20
constexpr index_t NPerBlock = 16; constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128; constexpr index_t KPerBlock = 128;
...@@ -406,12 +440,12 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc, ...@@ -406,12 +440,12 @@ void device_implicit_gemm_convolution_1_chwn_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i) for(index_t i = 0; i < nrepeat; ++i)
{ {
constexpr auto gridwise_conv = constexpr auto gridwise_conv =
#if 1 #if 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
#elif 1 #elif 1
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer
#endif #endif
<GridSize, <GridSize,
BlockSize, BlockSize,
......
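A quick consistency check of the new "3x3, 34x34, v1r2, Pascal" tuning block above (a host-side sketch, not part of the commit; the values are copied from the block, but the grouping of threads into GEMM clusters per Ho batch is my assumption):

    // Sanity-check sketch for the v1r2 parameters above.
    constexpr unsigned NPerBlock = 4, KPerBlock = 64, HoPerBlock = 4, WoPerBlock = 8;
    constexpr unsigned NPerThread = 4, KPerThread = 8, HoPerThread = 1, WoPerThread = 2;
    constexpr unsigned BlockSize = 128;
    // tile decomposition: (64/8) * (4/1) * (8/2) * (4/4) = 8 * 4 * 4 * 1 = 128 threads
    static_assert((KPerBlock / KPerThread) * (HoPerBlock / HoPerThread) *
                      (WoPerBlock / WoPerThread) * (NPerBlock / NPerThread) == BlockSize,
                  "thread tile decomposition should cover the block");
    // GEMM clusters: 4 * 2 * 2 * 2 = 32 threads per Ho batch, times HoPerBlock = 128
    constexpr unsigned GemmMLevel0Cluster = 4, GemmNLevel0Cluster = 2;
    constexpr unsigned GemmMLevel1Cluster = 2, GemmNLevel1Cluster = 2;
    static_assert(GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster *
                      GemmNLevel1Cluster * HoPerBlock == BlockSize,
                  "assumed GEMM-cluster-per-Ho-batch grouping should also cover the block");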
...@@ -35,6 +35,29 @@ struct GeneratorTensor_2 ...@@ -35,6 +35,29 @@ struct GeneratorTensor_2
} }
}; };
struct GeneratorTensor_3
{
int min_value = 0;
int max_value = 9;
template <class... Is>
double operator()(Is... is)
{
std::array<index_t, sizeof...(Is)> dims = {{static_cast<index_t>(is)...}};
#if 0
auto f_acc = std::plus<index_t>{};
#else
auto f_acc = [](auto a, auto b){ return 10*a + b;};
#endif
return std::accumulate(dims.begin(),
dims.end(),
index_t(0),
f_acc);
}
};
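The new GeneratorTensor_3 folds a multi-dimensional index into its decimal digits (the 10*a + b accumulator), so the value of each element tells you which index produced it when eyeballing debug output; note that min_value/max_value are currently unused by operator(). A minimal host-side illustration (not part of the commit):

    #include <array>
    #include <numeric>
    #include <cstdio>

    int main()
    {
        std::array<int, 3> dims = {1, 2, 3};
        auto f_acc = [](auto a, auto b) { return 10 * a + b; };
        int v = std::accumulate(dims.begin(), dims.end(), 0, f_acc);
        std::printf("%d\n", v); // prints 123: the digits encode the index (1,2,3)
        return 0;
    }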
struct GeneratorTensor_Checkboard struct GeneratorTensor_Checkboard
{ {
template <class... Ts> template <class... Ts>
...@@ -398,18 +421,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result) ...@@ -398,18 +421,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
#if 0 #if 1
constexpr index_t N = 1;
constexpr index_t C = 1;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
constexpr index_t K = 1;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 1
// 3x3, 34x34 // 3x3, 34x34
constexpr index_t N = 64; constexpr index_t N = 64;
constexpr index_t C = 256; constexpr index_t C = 256;
...@@ -656,13 +668,10 @@ int main(int argc, char* argv[]) ...@@ -656,13 +668,10 @@ int main(int argc, char* argv[])
#if 0 #if 0
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 0
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
#elif 1 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#elif 0 #elif 1
in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
auto gen_wei = [](auto... is) { auto gen_wei = [](auto... is) {
...@@ -681,7 +690,7 @@ int main(int argc, char* argv[]) ...@@ -681,7 +690,7 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1 #elif 1
device_implicit_gemm_convolution_1_chwn_cyxk_khwn device_implicit_gemm_convolution_1_chwn_cyxk_khwn
#elif 1 #elif 0
device_implicit_gemm_convolution_2_chwn_cyxk_khwn device_implicit_gemm_convolution_2_chwn_cyxk_khwn
#endif #endif
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
......
...@@ -8,6 +8,13 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>) ...@@ -8,6 +8,13 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
return Sequence<L1, 1>{}; return Sequence<L1, 1>{};
} }
// this is ugly, only for 3d
template <index_t L0, index_t L1, index_t L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
{
return Sequence<L1 * L2, L2, 1>{};
}
// this is ugly, only for 4d // this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3> template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>) __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
...@@ -79,6 +86,15 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0 ...@@ -79,6 +86,15 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0
return Sequence<L1_align, 1>{}; return Sequence<L1_align, 1>{};
} }
// this is ugly, only for 3d
template <index_t L0, index_t L1, index_t L2, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2>,
Number<Align>)
{
constexpr index_t L2_align = Align * ((L2 + Align - 1) / Align);
return Sequence<L1 * L2_align, L2_align, 1>{};
}
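A hedged worked example of the two new 3-d stride helpers, restated in plain C++ (the real code returns Sequence<> types; the lengths below are illustrative):

    // Packed row-major strides for lengths {L0, L1, L2} are {L1*L2, L2, 1};
    // the aligned variant pads the innermost length up to a multiple of Align.
    constexpr unsigned L0 = 8, L1 = 3, L2 = 34, Align = 4;
    // calculate_default_strides: {L1*L2, L2, 1} = {102, 34, 1}
    static_assert(L1 * L2 == 102, "packed outer stride");
    // calculate_default_strides_aligned: L2 padded to 36, strides {108, 36, 1}
    constexpr unsigned L2_align = Align * ((L2 + Align - 1) / Align);
    static_assert(L2_align == 36 && L1 * L2_align == 108, "aligned strides");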
// this is ugly, only for 4d // this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align> template <index_t L0, index_t L1, index_t L2, index_t L3, index_t Align>
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>, __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
...@@ -244,6 +260,22 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) ...@@ -244,6 +260,22 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I0), desc.GetStride(I0),
desc.GetStride(I1)); desc.GetStride(I1));
} }
else if(ndim == 3)
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
printf("%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n",
s,
desc.GetDimension(),
desc.GetLength(I0),
desc.GetLength(I1),
desc.GetLength(I2),
desc.GetStride(I0),
desc.GetStride(I1),
desc.GetStride(I2));
}
else if(ndim == 4) else if(ndim == 4)
{ {
constexpr auto I0 = Number<0>{}; constexpr auto I0 = Number<0>{};
......
...@@ -14,7 +14,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -14,7 +14,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths()); constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: "); print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: "); print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
...@@ -25,7 +25,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -25,7 +25,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
const index_t did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
...@@ -42,7 +42,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -42,7 +42,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < desc.GetElementSize()) if(is < desc.GetElementSize())
{ {
...@@ -59,7 +59,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -59,7 +59,7 @@ blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
} }
} }
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3] // Function: p_dst[reorder[i0], reorder[i1]] = p_src[i0,i1]
// TODO: in order to optimize mem access for different mem type, // TODO: in order to optimize mem access for different mem type,
// need to write specialized version // need to write specialized version
template <index_t BlockSize, template <index_t BlockSize,
...@@ -92,7 +92,7 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -92,7 +92,7 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[2]; index_t did[2];
...@@ -113,7 +113,7 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -113,7 +113,7 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
...@@ -162,15 +162,96 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, ...@@ -162,15 +162,96 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
} }
template <index_t BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths> template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise2dTensorCopy1 struct Blockwise2dTensorCopy1
{ {
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise2dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1),
"wrong! only support stride1 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
DstDesc{}.GetStride(I0) % DataPerRead == 0,
"src and dst stride2 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in D1 dimension,
// but we need to make sure dst stride0 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const __device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{ {
constexpr auto dst_from_src_reorder = Sequence<0, 1>{}; constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>( constexpr auto src_desc = SrcDesc{};
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[4];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
const index_t src_index =
src_desc.Get1dIndex(did[0], did[1] * DataPerRead);
const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
} }
}; };
......
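The rewritten Blockwise2dTensorCopy1 now moves DataPerRead elements per vector_t read/write and deliberately allows a small out-of-bound read along the innermost dimension; the constructor's static_assert guards the matching write by requiring the destination row stride to cover the rounded-up width. A hedged arithmetic example (the row length 34 is illustrative, chosen to show the rounding, not necessarily a length this copy sees in the kernels):

    constexpr unsigned L1 = 34, DataPerRead = 4;
    constexpr unsigned read_per_d1 = (L1 + DataPerRead - 1) / DataPerRead; // 9 float4 reads
    static_assert(read_per_d1 * DataPerRead == 36, "each row covers 36 elements");
    // DstDesc's stride0 must therefore be >= 36 (e.g. a stride produced by
    // calculate_default_strides_aligned with Align = 4) so the 2-element
    // overrun lands in padding rather than in the next row.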
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class CopyLengths,
index_t DataPerRead>
struct Blockwise3dTensorCopy1
{
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
__device__ constexpr Blockwise3dTensorCopy1()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
static_assert(DataPerRead == 1 ||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
"wrong! only support stride2 == 1 if DataPerRead > 1!\n");
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
DstDesc{}.GetStride(I1) % DataPerRead == 0,
"src and dst stride1 should be multiple of DataPerRead to keep alignment");
// we allow out-of-bound read from src in D2 dimension,
// but we need to make sure dst stride1 is big enough,
// so that the out-of-bound write won't contaminate next line in dst
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
"wrong! out-of-bound write will contaminate next line!\n");
}
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
constexpr index_t L0 = CopyLengths{}.Get(I0);
constexpr index_t L1 = CopyLengths{}.Get(I1);
constexpr index_t L2 = CopyLengths{}.Get(I2);
constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
constexpr auto ref_desc =
make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
auto f_copy = [&](index_t is) {
index_t did[3];
did[0] = is / ref_desc.GetStride(I0);
is -= did[0] * ref_desc.GetStride(I0);
did[1] = is / ref_desc.GetStride(I1);
is -= did[1] * ref_desc.GetStride(I1);
did[2] = is / ref_desc.GetStride(I2);
const index_t src_index =
src_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
const index_t dst_index =
dst_desc.Get1dIndex(did[0], did[1], did[2] * DataPerRead);
*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
*(reinterpret_cast<const vector_t*>(p_src + src_index));
};
for(index_t iloop = 0; iloop < NLoop; ++iloop)
{
index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is);
}
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
if(has_tail)
{
index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize())
{
f_copy(is);
}
}
}
};
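A hedged worked example for the new Blockwise3dTensorCopy1, using the v1r2 weight tile it is later instantiated with ([CPerBlock, X, KPerBlock] = [8, 3, 64], WeiBlockCopyDataPerRead = 4, BlockSize = 128; pairing exactly these numbers is my assumption):

    // With CopyLengths = [8, 3, 64] and DataPerRead = 4, the reference grid of
    // vector copies is [8, 3, 16]: 384 work items, i.e. exactly NLoop = 3 full
    // passes of a 128-thread block and no tail iteration.
    constexpr unsigned CPerBlock = 8, X = 3, KPerBlock = 64;
    constexpr unsigned DataPerRead = 4, BlockSize = 128;
    constexpr unsigned read_per_d2 = (KPerBlock + DataPerRead - 1) / DataPerRead; // 16
    constexpr unsigned work = CPerBlock * X * read_per_d2;                         // 384
    static_assert(work / BlockSize == 3 && work % BlockSize == 0, "3 passes, no tail");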
...@@ -15,7 +15,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -15,7 +15,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths()); constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: "); print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: "); print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
...@@ -26,7 +26,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -26,7 +26,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
const index_t did0 = is / desc.GetStride(I0); const index_t did0 = is / desc.GetStride(I0);
...@@ -51,7 +51,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst ...@@ -51,7 +51,7 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < desc.GetElementSize()) if(is < desc.GetElementSize())
{ {
...@@ -113,7 +113,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -113,7 +113,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[4]; index_t did[4];
...@@ -142,7 +142,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds ...@@ -142,7 +142,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
...@@ -287,7 +287,7 @@ struct Blockwise4dTensorCopy1 ...@@ -287,7 +287,7 @@ struct Blockwise4dTensorCopy1
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
f_copy(is); f_copy(is);
} }
...@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1 ...@@ -296,7 +296,7 @@ struct Blockwise4dTensorCopy1
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
...@@ -370,7 +370,7 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -370,7 +370,7 @@ struct BlockwiseChwnTensorCopyPadded
for(index_t iloop = 0; iloop < NLoop; ++iloop) for(index_t iloop = 0; iloop < NLoop; ++iloop)
{ {
index_t is = threadIdx.x + iloop * BlockSize; index_t is = get_thread_local_1d_id() + iloop * BlockSize;
index_t did[4]; index_t did[4];
...@@ -401,7 +401,7 @@ struct BlockwiseChwnTensorCopyPadded ...@@ -401,7 +401,7 @@ struct BlockwiseChwnTensorCopyPadded
if(has_tail) if(has_tail)
{ {
index_t is = threadIdx.x + NLoop * BlockSize; index_t is = get_thread_local_1d_id() + NLoop * BlockSize;
if(is < ref_desc.GetElementSize()) if(is < ref_desc.GetElementSize())
{ {
......
...@@ -250,6 +250,15 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -250,6 +250,15 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
} }
} }
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3], p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7],
p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3], p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]);
}
#endif
threadwise_gemm(a_thread_mtx, threadwise_gemm(a_thread_mtx,
True, True,
p_a_thread, p_a_thread,
...@@ -313,6 +322,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 ...@@ -313,6 +322,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
MPerThread == 8 && NPerThread == 8, MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n"); "Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only does float4 reads\n");
static_assert( static_assert(
BlockMatrixStrideA == 0 && BatchPerThread == 1, BlockMatrixStrideA == 0 && BatchPerThread == 1,
"Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n"); "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
......
...@@ -42,7 +42,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -42,7 +42,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread; constexpr index_t XThreadWork = (out_block_desc.GetLength(I3) + WoPerThread - 1) / WoPerThread;
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(in_block_desc); print_ConstantTensorDescriptor(in_block_desc);
print_ConstantTensorDescriptor(wei_block_desc); print_ConstantTensorDescriptor(wei_block_desc);
...@@ -68,7 +68,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc, ...@@ -68,7 +68,7 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
constexpr auto out_thread_block_desc = constexpr auto out_thread_block_desc =
make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides()); make_ConstantTensorDescriptor(out_thread_desc.GetLengths(), out_block_desc.GetStrides());
const index_t thread_id = threadIdx.x; const index_t thread_id = get_thread_local_1d_id();
for(index_t thread_work_id = thread_id; for(index_t thread_work_id = thread_id;
thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork; thread_work_id < NThreadWork * KThreadWork * YThreadWork * XThreadWork;
......
...@@ -176,6 +176,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 ...@@ -176,6 +176,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
MPerThread == 8 && NPerThread == 8, MPerThread == 8 && NPerThread == 8,
"Run_asm cannot deal with this GEMM shape yet\n"); "Run_asm cannot deal with this GEMM shape yet\n");
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only does float4 reads\n");
using Float4 = vector_type<float, 4>::MemoryType; using Float4 = vector_type<float, 4>::MemoryType;
Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_a = (Float4*)(p_a_thread);
......
...@@ -36,7 +36,7 @@ template <index_t GridSize, ...@@ -36,7 +36,7 @@ template <index_t GridSize,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite> index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
...@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -100,7 +100,7 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
// tensor view of blockwise input and weight in LDS // tensor view of blockwise input and weight in LDS
// be careful of alignment // be careful of alignment
constexpr index_t max_align = constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{}); Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
...@@ -118,6 +118,14 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -118,6 +118,14 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy = const auto blockwise_in_copy =
#if 1
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#else
Blockwise4dTensorCopy3<BlockSize, Blockwise4dTensorCopy3<BlockSize,
Float, Float,
decltype(in_chwn_global_desc), decltype(in_chwn_global_desc),
...@@ -125,6 +133,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn ...@@ -125,6 +133,8 @@ struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn
decltype(in_chwn_block_desc.GetLengths()), decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims, InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{}; InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock] // format is [CPerBlock*Y*X,KPerBlock]
......
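Among the changes shown in this hunk, besides the v1r1 rename and the switch to the simpler Blockwise4dTensorCopy1 input copy, the LDS alignment is now the maximum of every vectorized access width (block copies and GEMM reads) instead of being clamped at 4. A minimal restatement of that rule (the widths below are placeholders):

    #include <algorithm>

    constexpr unsigned in_read = 4, wei_read = 4, gemm_read_a = 4, gemm_read_b = 4;
    constexpr unsigned max_align = std::max({in_read, wei_read, gemm_read_a, gemm_read_b});
    static_assert(max_align == 4, "LDS tiles stay aligned for the widest vector access");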
...@@ -36,7 +36,7 @@ template <index_t GridSize, ...@@ -36,7 +36,7 @@ template <index_t GridSize,
index_t InBlockCopyDataPerRead, index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead, index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite> index_t OutThreadCopyDataPerWrite>
struct GridwiseConvolutionImplicitGemm_v1_chwn_cyxk_khwn_lds_double_buffer struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
{ {
__device__ void Run(const Float* const __restrict__ p_in_global, __device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global, const Float* const __restrict__ p_wei_global,
......
...@@ -2,8 +2,9 @@ ...@@ -2,8 +2,9 @@
#include "common.hip.hpp" #include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp" #include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp" #include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp" #include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_3d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp" #include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp" #include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp" #include "blockwise_batched_gemm.hip.hpp"
...@@ -52,19 +53,19 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -52,19 +53,19 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
constexpr auto I2 = Number<2>{}; constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{}; constexpr auto I3 = Number<3>{};
constexpr auto in_chwn_global_desc = InGlobalDesc{}; constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{}; constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
constexpr index_t C = in_chwn_global_desc.GetLength(I0); constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);
constexpr index_t K = out_khwn_global_desc.GetLength(I0); constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3); constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
constexpr index_t HiPerBlock = HoPerBlock + Y - 1; constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1; constexpr index_t WiPerBlock = WoPerBlock + X - 1;
...@@ -94,53 +95,54 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -94,53 +95,54 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
const index_t hi_block_data_begin = ho_block_data_begin; const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin; const index_t wi_block_data_begin = wo_block_data_begin;
// 2d tensor view of gridwise weight // global tensor view
constexpr auto wei_ck_global_desc = constexpr auto wei_c_x_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{}); make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});
// tensor view of blockwise input and weight in LDS // LDS tensor view
// be careful of alignment // be careful of alignment
constexpr index_t max_align = constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); mod_conv::max(InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{}); Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto wei_ck_block_desc = make_ConstantTensorDescriptor_aligned( constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{}); Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});
// tensor view of threadwise output in register // tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor( constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{}); Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
// blockwise copy // blockwise copy
// input: format is [C, Hi, Wi, N] // input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy = const auto blockwise_in_copy =
#if 1
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
#else
Blockwise4dTensorCopy3<BlockSize, Blockwise4dTensorCopy3<BlockSize,
Float, Float,
decltype(in_chwn_global_desc), decltype(in_c_h_w_n_global_desc),
decltype(in_chwn_block_desc), decltype(in_c_h_w_n_block_desc),
decltype(in_chwn_block_desc.GetLengths()), decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyThreadPerDims, InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{}; InBlockCopyDataPerRead>{};
#endif
// blockwise wei copy // blockwise wei copy
// format is [CPerBlock, KPerBlock] // format is [CPerBlock, X, KPerBlock]
const auto blockwise_wei_copy = const auto blockwise_wei_copy =
#if 0 // debug Blockwise3dTensorCopy1<BlockSize,
Blockwise2dTensorCopy1<BlockSize,
Float,
decltype(wei_ck_global_desc),
decltype(wei_ck_block_desc),
decltype(wei_ck_block_desc.GetLengths())>{};
#else
Blockwise2dTensorCopy3<BlockSize,
Float, Float,
decltype(wei_ck_global_desc), decltype(wei_c_x_k_global_desc),
decltype(wei_ck_block_desc), decltype(wei_c_x_k_block_desc),
decltype(wei_ck_block_desc.GetLengths()), decltype(wei_c_x_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{}; WeiBlockCopyDataPerRead>{};
#endif
// a series of blockwise batched GEMM // a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix // C_matrix += transpose(A_matrix) * B_matrix
...@@ -148,30 +150,30 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -148,30 +150,30 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// A_matrix[C,K] is a sub-matrix of wei_block[C,K] // A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_cxk_block_mtx_desc = constexpr auto a_c_k_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{}, make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<KPerBlock>{}, Number<KPerBlock>{},
Number<wei_ck_block_desc.GetStride(I0)>{}); Number<wei_c_x_k_block_desc.GetStride(I0)>{});
constexpr auto b_cxwn_block_mtx_desc = constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{}, make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{}, Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{}); Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
constexpr auto c_kxwn_thread_mtx_desc = constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{}, make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{}, Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I0)>{}); Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
const auto blockwise_batch_gemm = const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize, BlockSize,
decltype(a_cxk_block_mtx_desc), decltype(a_c_k_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc), decltype(b_c_wn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc), decltype(c_k_wn_thread_mtx_desc),
0, 0,
in_chwn_block_desc.GetStride(I1), in_c_h_w_n_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1), out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock, HoPerBlock,
GemmMPerThreadSubC, GemmMPerThreadSubC,
GemmNPerThreadSubC, GemmNPerThreadSubC,
...@@ -185,64 +187,64 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -185,64 +187,64 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmDataPerReadB>{}; GemmDataPerReadB>{};
// LDS: be careful of alignment // LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_ck_block_desc.GetElementSpace(Number<max_align>{}); constexpr index_t wei_block_space = wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
__shared__ Float p_in_block[in_block_space]; __shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space]; __shared__ Float p_wei_block[wei_block_space];
// register // register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()]; Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
#if 0 #if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(in_chwn_global_desc, "in_chwn_global_desc"); print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
print_ConstantTensorDescriptor(wei_cyxk_global_desc, "wei_cyxk_global_desc"); print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
print_ConstantTensorDescriptor(wei_ck_global_desc, "wei_ck_global_desc");
print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
print_ConstantTensorDescriptor(wei_ck_block_desc, "wei_ck_block_desc"); print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
} }
#endif #endif
// set threadwise output tensor to 0 // set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread); threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
const Float* p_in_global_block_offset = const Float* p_in_global_block_offset =
p_in_global + p_in_global +
in_chwn_global_desc.Get1dIndex( in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
const Float* p_wei_global_block_offset = const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin); p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0), p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0)) p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{ {
// input: global mem to LDS
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
for(index_t y = 0; y < Y; ++y) for(index_t y = 0; y < Y; ++y)
{ {
for(index_t x = 0; x < X; ++x) blockwise_in_copy.Run(p_in_global_block_offset +
{ in_c_h_w_n_global_desc.Get1dIndex(0, y, 0, 0),
// weight: global mem to LDS p_in_block);
blockwise_wei_copy.Run(p_wei_global_block_offset + blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_cyxk_global_desc.Get1dIndex(0, y, x, 0), wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
p_wei_block); p_wei_block);
__syncthreads(); __syncthreads();
blockwise_batch_gemm.Run(p_wei_block, for(index_t x = 0; x < X; ++x)
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0), {
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
p_out_thread); p_out_thread);
__syncthreads();
} }
__syncthreads();
} }
} }
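This loop is the heart of the commit ("only load 1d filter"): the old structure loaded the whole input tile once per C-block and a [CPerBlock, KPerBlock] weight slice for every (y, x) point, whereas each y iteration now loads one HoPerBlock-high input slab plus a single filter row [CPerBlock, X, KPerBlock] into LDS and slides the batched GEMM over x entirely out of LDS. A self-contained host-side skeleton of the new loop (load_input_slab / load_filter_row / batched_gemm are stand-ins for blockwise_in_copy.Run, blockwise_wei_copy.Run and blockwise_batch_gemm.Run, and the __syncthreads() barriers are omitted; the counts use the driver's C = 256, CPerBlock = 8 and 3x3 filter):

    #include <cstdio>

    int main()
    {
        constexpr int C = 256, CPerBlock = 8, Y = 3, X = 3;
        int in_loads = 0, wei_loads = 0, gemms = 0;
        auto load_input_slab = [&](int /*c0*/, int /*y*/) { ++in_loads; };  // blockwise_in_copy.Run
        auto load_filter_row = [&](int /*c0*/, int /*y*/) { ++wei_loads; }; // blockwise_wei_copy.Run: one y row, all x
        auto batched_gemm    = [&](int /*x*/) { ++gemms; };                 // blockwise_batch_gemm.Run

        for(int c0 = 0; c0 < C; c0 += CPerBlock)
            for(int y = 0; y < Y; ++y)
            {
                load_input_slab(c0, y);
                load_filter_row(c0, y); // the "only load 1d filter" change
                for(int x = 0; x < X; ++x)
                    batched_gemm(x); // A/B offsets slide over x inside LDS
            }

        // 96 filter-row loads replace the 288 per-(y, x) weight loads of the old loop
        std::printf("in %d, wei %d, gemm %d\n", in_loads, wei_loads, gemms); // in 96, wei 96, gemm 288
        return 0;
    }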
...@@ -324,7 +326,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn ...@@ -324,7 +326,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
p_out_thread, p_out_thread,
out_10d_global_desc, out_10d_global_desc,
p_out_global + p_out_global +
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin, out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin, ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin, wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin), n_block_data_begin + n_thread_data_begin),
......
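For reference, the wei_c_x_k_global_desc introduced in this file is a strided 3-d [C, X, K] window into the 4-d [C, Y, X, K] filter: keeping the C stride at Y*X*K makes the view address a single y row, and the y offset is applied separately through wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0). A hedged arithmetic sketch (K = 128 is a made-up value; the driver's K is not shown in these hunks):

    // Strides of the [C, X, K] window are {Y*X*K, K, 1}. With Y = X = 3 and a
    // hypothetical K = 128 they are {1152, 128, 1}, so element (c, x, k) of the
    // view lands at c*1152 + x*128 + k, i.e. the same y row for every c.
    constexpr unsigned Y = 3, X = 3, K = 128;
    constexpr unsigned stride_c = Y * X * K, stride_x = K, stride_k = 1;
    static_assert(stride_c == 1152 && stride_x == 128 && stride_k == 1, "strided y-row view");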
...@@ -125,7 +125,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i ...@@ -125,7 +125,7 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const index_t thread_id = threadIdx.x; const index_t thread_id = get_thread_local_1d_id();
itmp = thread_id; itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
......
...@@ -137,7 +137,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ...@@ -137,7 +137,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread;
constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread;
const index_t thread_id = threadIdx.x; const index_t thread_id = get_thread_local_1d_id();
itmp = thread_id; itmp = thread_id;
const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork);
......
...@@ -10,7 +10,7 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -10,7 +10,7 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
constexpr auto desc = Desc{}; constexpr auto desc = Desc{};
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: "); print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
} }
...@@ -112,7 +112,7 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -112,7 +112,7 @@ __device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
constexpr auto desc = Desc{}; constexpr auto desc = Desc{};
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: "); print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
} }
......
...@@ -12,7 +12,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re ...@@ -12,7 +12,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
constexpr auto desc = Desc{}; constexpr auto desc = Desc{};
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: "); print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
} }
...@@ -218,7 +218,7 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi ...@@ -218,7 +218,7 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDi
constexpr auto desc = Desc{}; constexpr auto desc = Desc{};
#if 0 #if 0
if(threadIdx.x == 0) if(get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: "); print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
} }
......
...@@ -20,7 +20,7 @@ __device__ void threadwise_direct_convolution_1(InDesc, ...@@ -20,7 +20,7 @@ __device__ void threadwise_direct_convolution_1(InDesc,
constexpr auto out_desc = OutDesc{}; constexpr auto out_desc = OutDesc{};
#if 0 #if 0
if(blockIdx.x == 0 && threadIdx.x == 0) if(blockIdx.x == 0 && get_thread_local_1d_id() == 0)
{ {
print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: "); print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: ");
print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: "); print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: ");
......