"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "70645f4d7d1447d2c5aa0667b88af92a20018b17"
Unverified Commit 3406a114 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Update for recent MIOpen integration (#11)

* update for MIOpen integration
parent c5da0377
......@@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
......@@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
"be violated");
// output tensor
constexpr auto out_n_k_howo_global_desc =
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
constexpr auto out_k_b_global_desc =
transform_tensor_descriptor(out_n_k_howo_global_desc,
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
......
......@@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
}
{
#if 1 // debug
#if 1 // debug
// input: register to global memory, atomic add
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
? InMemoryDataOperation::none
......
......@@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......@@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......@@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......
......@@ -22,8 +22,6 @@ template <index_t GridSize,
typename ConvDilations,
typename InLeftPads,
typename InRightPads,
index_t Iter_ytilda,
index_t Iter_xtilda,
index_t GemmMPerBlock,
index_t GemmNPerBlock,
index_t GemmKPerBlock,
......@@ -47,9 +45,27 @@ template <index_t GridSize,
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
{
__device__ void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global) const
__host__ __device__ static constexpr index_t GetNumberOfGemm()
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
return Ytilda * Xtilda;
}
template <index_t iYTilda, index_t iXTilda>
__device__ static void RunImpl(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
......@@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
"be violated");
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......@@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<C>{},
Embed<Y,
Sequence<Ydot, Ytilda>,
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
wei_skip_all_out_of_bound_check>{},
Embed<X,
Sequence<Xdot, Xtilda>,
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
wei_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
PassThrough<K>{},
Embed<Ho,
Sequence<Ydot, Htilda>,
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
out_skip_all_out_of_bound_check>{},
Embed<Wo,
Sequence<Xdot, Wtilda>,
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
out_skip_all_out_of_bound_check>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
......@@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
// GEMM
constexpr index_t ytilda = Iter_ytilda;
constexpr index_t xtilda = Iter_xtilda;
constexpr index_t ytilda = iYTilda;
constexpr index_t xtilda = iXTilda;
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
......@@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
}
template <index_t GemmId>
__device__ static void Run(Float* __restrict__ p_in_global,
const Float* __restrict__ p_wei_global,
const Float* __restrict__ p_out_global)
{
constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
constexpr index_t ConvDilationH = ConvDilations{}[0];
constexpr index_t ConvDilationW = ConvDilations{}[1];
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t iYTilda = GemmId / Xtilda;
constexpr index_t iXTilda = GemmId % Xtilda;
static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
}
};
} // namespace ck
......
......@@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
......@@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
// output tensor
constexpr auto out_k_b_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// GEMM
......
......@@ -47,6 +47,9 @@ struct PassThrough
}
};
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
// necessary
// However, the check will be skipped if SkipIsValidCheck is set to true by user
// LowerLengths: Sequence<...>
template <typename LowerLengths,
typename LeftPads,
......@@ -92,12 +95,12 @@ struct Pad
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
// skip valid check if user request it
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
for(index_t i = 0; i < nDim; ++i)
......@@ -384,6 +387,9 @@ struct UnMerge
}
};
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
// necessary
// However, the check will be skipped if SkipIsValidCheck is set to true by user
// UpperLengths: Sequence<...>
// Coefficients: Sequence<...>
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
......@@ -442,12 +448,12 @@ struct Embed
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
#if 1 // debug
// skip valid check if user request it
if(SkipIsValidCheck)
{
return true;
}
#endif
bool flag = true;
index_t ncorner = 1;
......
......@@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
......@@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
}
}
......@@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(p_src,
src_nonlinear_coord.GetOffset() +
src_linear_offset,
p_src_long_vector,
buffer_offset);
}
}
......@@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
}
}
......@@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
transfer_data<SrcData,
SrcDataPerRead,
SrcAddressSpace,
AddressSpace::vgpr,
InMemoryDataOperation::none>(
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
}
}
......@@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
// has the valid/invalid mapping situation
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
move_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
buffer_offset,
p_dst,
dst_nonlinear_coord.GetOffset() + dst_linear_offset);
transfer_data<DstData,
DstDataPerWrite,
AddressSpace::vgpr,
DstAddressSpace,
DstInMemOp>(p_dst_long_vector,
buffer_offset,
p_dst,
dst_nonlinear_coord.GetOffset() +
dst_linear_offset);
}
}
});
......
......@@ -8,19 +8,12 @@ namespace ck {
// outer-product: c[i,j] += inner_product(a[i], b[j])
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
// disable inline asm due to the compiler issue: SWDEV-202749
///\to-do: enable the inline asm after the compiler fix
#if CK_WORKAROUND_SWDEV_202749
c0 += a * b0;
c1 += a * b1;
#else
asm volatile("\n \
v_mac_f32 %0, %2, %3 \n \
v_mac_f32 %1, %2, %4 \n \
"
: "=v"(c0), "=v"(c1)
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
}
// outer-product: c[i,j] += inner_product(a[i], b[j])
......
......@@ -43,6 +43,10 @@
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
#endif
#ifndef CK_USE_AMD_XDLOPS_EMULATE
#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
#endif
// experimental implementation
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
......@@ -51,9 +55,6 @@
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
// workaround
#define CK_WORKAROUND_SWDEV_202749 1
namespace ck {
enum AddressSpace
......
......@@ -70,7 +70,7 @@ template <typename T,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp>
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::none ||
DstInMemOp == InMemoryDataOperation::atomic_add,
......
......@@ -38,7 +38,7 @@ template <typename T,
AddressSpace SrcAddressSpace,
AddressSpace DstAddressSpace,
InMemoryDataOperation DstInMemOp>
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
{
static_assert(DstInMemOp == InMemoryDataOperation::none ||
DstInMemOp == InMemoryDataOperation::atomic_add,
......
......@@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
return x < y ? x : y;
}
// highest common factor
// greatest common divisor, aka highest common factor
template <typename T>
__host__ __device__ constexpr T hcf(T x, T y)
__host__ __device__ constexpr T gcd(T x, T y)
{
if(x == 0)
{
......@@ -124,30 +124,30 @@ __host__ __device__ constexpr T hcf(T x, T y)
if(x > y)
{
return hcf(x - y, y);
return gcd(x - y, y);
}
return hcf(x, y - x);
return gcd(x, y - x);
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
constexpr auto result = hcf(X, Y);
constexpr auto result = gcd(X, Y);
return Number<result>{};
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto hcf(X x, Ys... ys)
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
return hcf(x, ys...);
return gcd(x, ys...);
}
// least common multiple
template <typename T>
__host__ __device__ constexpr T lcm(T x, T y)
{
return (x * y) / hcf(x, y);
return (x * y) / gcd(x, y);
}
template <typename X, typename Y, typename... Zs>
......
......@@ -152,11 +152,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......
......@@ -91,11 +91,11 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......
......@@ -2,13 +2,18 @@
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
namespace launcher {
using namespace ck;
template <typename GridwiseOp, index_t GemmId, typename... Xs>
__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
{
GridwiseOp::template Run<GemmId>(xs...);
}
template <typename T,
typename InDesc,
typename WeiDesc,
......@@ -119,11 +124,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#endif
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
......@@ -154,69 +159,61 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
for(index_t i = 0; i < nrepeat; ++i)
{
KernelTimer timer;
using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>;
KernelTimer timer;
timer.Start();
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
constexpr index_t ytilda = decltype(ytilda_){};
constexpr index_t xtilda = decltype(xtilda_){};
constexpr auto gridwise_conv =
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
GridSize,
BlockSize,
T,
T,
decltype(in_nchw_desc),
decltype(wei_kcyx_desc),
decltype(out_nkhw_desc),
ConvStrides,
ConvDilations,
InLeftPads,
InRightPads,
ytilda,
xtilda,
GemmMPerBlock,
GemmNPerBlock,
GemmKPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
GemmThreadGemmDataPerReadM,
GemmThreadGemmDataPerReadN,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
GemmABlockCopySrcDataPerRead_GemmM,
GemmABlockCopyDstDataPerWrite_GemmM,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
GemmBBlockCopySrcDataPerRead_GemmN,
GemmBBlockCopyDstDataPerWrite_GemmN,
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
gridwise_conv,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
});
static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
constexpr index_t gemm_id = decltype(gemm_id_){};
launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
gemm_id,
T* const __restrict__,
const T* const __restrict__,
const T* const __restrict__>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
});
timer.End();
float time = timer.GetElapsedTime();
printf("Elapsed time : %f ms, %f TFlop/s\n",
......
......@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
#if 1
#if 0
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
constexpr index_t BlockSize = 256;
......@@ -127,7 +127,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 1
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
// for 1x1
constexpr index_t BlockSize = 256;
constexpr index_t BPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t EPerBlock = 16;
constexpr index_t GemmNRepeat = 2;
constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 4;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 4;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>;
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#elif 1
// BlockSize = 64, each thread hold 64 data
constexpr index_t BlockSize = 64;
......
......@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
#elif 0
// BlockSize = 256, GemmKPerBlock = 16
constexpr index_t BlockSize = 256;
......@@ -117,7 +117,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 0
// BlockSize = 256, GemmKPerBlock = 8
// 1x1 filter, 8x8 image
// for 1x1 filter, vector-read-b = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
......@@ -149,7 +149,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
#elif 1
// BlockSize = 256, GemmKPerBlock = 16
// 1x1 filter, 8x8 image
// for 1x1 filter, vector-read-b = 4
constexpr index_t BlockSize = 256;
constexpr index_t GemmMPerBlock = 128;
......
......@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
#elif 1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t C = 128;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 1024;
constexpr index_t K = 128;
constexpr index_t Y = 1;
constexpr index_t X = 7;
......@@ -246,28 +246,28 @@ int main(int argc, char* argv[])
#endif
}
#if 0
#if 1
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
#elif 1
#elif 0
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
#elif 0
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
#elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif
(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
(in_nchw_desc,
in_nchw_device,
wei_kcyx_desc,
wei_kcyx,
out_nkhw_desc,
out_nkhw,
ConvStrides{},
ConvDilations{},
LeftPads{},
RightPads{},
nrepeat);
if(do_verification)
{
......
......@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
{
using namespace ck;
#if 0
#if 1
// 1x1
constexpr index_t N = 256;
constexpr index_t C = 1024;
constexpr index_t HI = 8;
constexpr index_t WI = 8;
constexpr index_t K = 1024;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment