Commit 3317bfe2 authored by Jing Zhang

format

parent 2b8e3ece
@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
                make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
        constexpr auto out_gemmm_gemmn_global_desc =
            transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
                                        make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
...
@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single half to 4 packed half/2 packed bfloat16
                // respectively.
-               auto p_a_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_a_block_now);
-               auto p_b_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_b_block_now);
+               auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
+               auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
                // LDS double buffer: store next data to LDS
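The single-line casts kept by this hunk preserve the existing vectorization scheme: a scalar `half` buffer is reinterpreted as packed `half4_t` so each index step covers four values at once. A standalone sketch of that recast, assuming a Clang-style `ext_vector_type` alias (CK defines its own `half4_t`; the aliases here are only for illustration):

```cpp
// Minimal, self-contained sketch of the half -> half4_t recast (illustrative
// type aliases; not CK's actual definitions).
#include <cstdio>

using half_t  = _Float16;                                   // assumes a Clang target with _Float16
using half4_t = half_t __attribute__((ext_vector_type(4))); // 4 packed halves

int main()
{
    // 8 scalar halves laid out contiguously, as a block of LDS would be;
    // alignas keeps the vector reinterpretation well-aligned.
    alignas(alignof(half4_t)) half_t buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};

    // Recast: one half4_t element now spans four consecutive halves, so the
    // "2D indexes computed with vectorized value in mind" advance in 4-wide steps.
    auto p_vec = reinterpret_cast<const half4_t*>(buf);

    for(int i = 0; i < 2; ++i)
        printf("p_vec[%d] = {%g, %g, %g, %g}\n",
               i,
               (float)p_vec[i][0],
               (float)p_vec[i][1],
               (float)p_vec[i][2],
               (float)p_vec[i][3]);
    return 0;
}
```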
@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
                // LDS double buffer: GEMM on 2nd-last data
-               auto p_a_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_a_block_double);
-               auto p_b_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_b_block_double);
+               auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
+               auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
                // LDS double buffer: store last data to LDS
@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
                __syncthreads();
                // LDS double buffer: GEMM on current data
-               p_a_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_a_block_double + a_block_space);
-               p_b_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_b_block_double + b_block_space);
+               p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
+               p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
            else // if has 1 iteration left
@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
                __syncthreads();
                // LDS double buffer: GEMM on last data
-               auto p_a_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_a_block_double);
-               auto p_b_block_vec =
-                   reinterpret_cast<const half4_t*>(
-                       p_b_block_double);
+               auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
+               auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
                blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
            }
        }
@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
            }
        }
    }
};
}
#endif
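The hunks above all touch CK's LDS double-buffer pipeline: while the blockwise GEMM consumes one half of LDS, the next K-slice is prefetched and stored into the other half, and a tail section handles the last one or two slices outside the main loop. A schematic of that ping-pong structure, written as a host-side C++ sketch with illustrative stand-in helpers rather than CK's real API:

```cpp
// Schematic LDS double-buffer pipeline (host-side stand-in; load_tile and
// gemm_on_tile are illustrative, not CK functions).
#include <cstddef>
#include <cstdio>
#include <vector>

static void load_tile(std::vector<float>& dst, std::size_t tile)
{
    for(std::size_t i = 0; i < dst.size(); ++i)
        dst[i] = static_cast<float>(tile + 1); // stands in for the global->LDS copy
}

static float gemm_on_tile(const std::vector<float>& buf)
{
    float acc = 0.f;
    for(float v : buf)
        acc += v; // stands in for blockwise_gemm.Run on one K-slice
    return acc;
}

int main()
{
    const std::size_t num_tiles = 8, tile_size = 16;

    // Two halves of the "LDS" allocation: buf[0] and buf[1] play the roles of
    // p_x_block_double and p_x_block_double + x_block_space.
    std::vector<float> buf[2] = {std::vector<float>(tile_size),
                                 std::vector<float>(tile_size)};
    float c = 0.f;

    load_tile(buf[0], 0); // prologue: fetch the first slice

    for(std::size_t t = 0; t < num_tiles; ++t)
    {
        if(t + 1 < num_tiles)
            load_tile(buf[(t + 1) % 2], t + 1); // prefetch next slice into the other half
        c += gemm_on_tile(buf[t % 2]);          // GEMM on the current slice
    }

    printf("accumulated C = %f\n", c);
    return 0;
}
```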
@@ -810,7 +810,7 @@ struct XdlopsGemm_t
                                   (mfma_type.group_size * mfma_type.num_input_blks);
                index_t bindex = blk_td;
                p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
            }
        }
    }
...
@@ -13,16 +13,16 @@ template <class T,
          class InLeftPads,
          class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
                                                                       const Tensor<T>& in_nchw,
                                                                       WeiDesc,
                                                                       const Tensor<T>& wei_kcyx,
                                                                       OutDesc,
                                                                       Tensor<T>& out_nkhw,
                                                                       ConvStrides,
                                                                       ConvDilations,
                                                                       InLeftPads,
                                                                       InRightPads,
                                                                       ck::index_t nrepeat)
{
    using namespace ck;
@@ -60,7 +60,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmKPerBlock = 4;
    constexpr index_t GemmKPACK = 4;
    constexpr index_t GemmMPerWave = 64;
    constexpr index_t GemmNPerWave = 64;
@@ -76,7 +76,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK = Sequence<1, 2, 4>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK = Sequence<4, 32, 1>;
    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK = 1;
    constexpr index_t GemmM = K;
@@ -87,51 +87,52 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-   constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
+   constexpr auto gridwise_conv =
+       GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
            GridSize,
            BlockSize,
            half,
            float,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            InLeftPads,
            InRightPads,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmKPACK,
            GemmMPerWave,
            GemmNPerWave,
            ThreadGemmDataPerReadM,
            ThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
            GemmABlockCopySrcDataPerRead_GemmKPACK,
            GemmABlockCopyDstDataPerWrite_GemmKPACK,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmKPACK>{};
    for(index_t i = 0; i < 10; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
    }
    // warm up
@@ -139,14 +140,14 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }
    printf("Start running %d times...\n", nrepeat);
@@ -156,25 +157,25 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }
    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();
    float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;
    printf("Average elapsed time : %f ms, %f TFlop/s\n",
           ave_time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / ave_time);
    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
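Both printf calls above rely on the same identity: flops / 1e9 / time_ms equals flops / (1e12 * time_s), so dividing the flop count by 10^9 and by the elapsed milliseconds yields TFlop/s directly. A small worked check with made-up numbers:

```cpp
// Worked check of the TFlop/s conversion used above (flop count and timing
// are made-up example numbers, not measurements).
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t flops = std::size_t(1000) * 1000 * 1000 * 1000; // 1 TFlop of work
    const float time_ms     = 50.0f;                                  // elapsed kernel time

    // flops / 1e9 / ms  ==  flops / (1e12 * s)  ==  TFlop/s
    const float tflops = (float)flops / (std::size_t(1000) * 1000 * 1000) / time_ms;

    printf("%zu flops in %f ms -> %f TFlop/s\n", flops, time_ms, tflops); // prints 20 TFlop/s
    return 0;
}
```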
@@ -618,16 +618,16 @@ int main(int argc, char* argv[])
                                                                      nrepeat);
#elif 1
    device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
                                                                     in_nchw,
                                                                     wei_kcyx_desc,
                                                                     wei_kcyx,
                                                                     out_nkhw_desc,
                                                                     out_nkhw_device,
                                                                     ConvStrides{},
                                                                     ConvDilations{},
                                                                     LeftPads{},
                                                                     RightPads{},
                                                                     nrepeat);
#endif
    if(do_verification)
...