Commit 3317bfe2 authored by Jing Zhang

format

parent 2b8e3ece
@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
constexpr auto out_gemmm_gemmn_global_desc =
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
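The PassThrough/Merge transform in the hunk above views the NKHW output tensor as a GEMM output matrix. A minimal standalone sketch of the index map it expresses follows, assuming the merged (N, Ho, Wo) dimensions are flattened in row-major order; the struct and function names are illustrative, not part of the repository.

// GemmM = K (passed through), GemmN = merged (N, Ho, Wo), flattened row-major.
struct GemmCoord
{
    int gemm_m; // the K index, unchanged by PassThrough<K>
    int gemm_n; // the merged (n, ho, wo) index
};

inline GemmCoord map_nkhw_output_to_gemm(int n, int k, int ho, int wo, int Ho, int Wo)
{
    return GemmCoord{k, (n * Ho + ho) * Wo + wo};
}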
@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
// 2D indexes are computed with a vectorized value type in mind (e.g. float, half2, half4),
// so we recast the data from single halfs to 4 packed halfs / 2 packed bfloat16s,
// respectively.
- auto p_a_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_a_block_now);
- auto p_b_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_b_block_now);
+ auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_now);
+ auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_now);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store next data to LDS
@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
// LDS double buffer: GEMM on 2nd-last data
- auto p_a_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_a_block_double);
- auto p_b_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_b_block_double);
+ auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
+ auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
// LDS double buffer: store last data to LDS
@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads();
// LDS double buffer: GEMM on current data
- p_a_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_a_block_double + a_block_space);
- p_b_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_b_block_double + b_block_space);
+ p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double + a_block_space);
+ p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double + b_block_space);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
else // if has 1 iteration left
@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads();
// LDS double buffer: GEMM on last data
- auto p_a_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_a_block_double);
- auto p_b_block_vec =
-     reinterpret_cast<const half4_t*>(
-         p_b_block_double);
+ auto p_a_block_vec = reinterpret_cast<const half4_t*>(p_a_block_double);
+ auto p_b_block_vec = reinterpret_cast<const half4_t*>(p_b_block_double);
blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, p_c_thread);
}
}
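The comment near the top of this file's diff explains the recasts used throughout the LDS double-buffer pipeline above: the xdlops path consumes packed 4-wide fp16 (or 2-wide bfloat16) values, so scalar half pointers into the LDS blocks are reinterpreted as vector pointers before blockwise_gemm.Run. A minimal standalone sketch of that recast follows; the half_t/half4_t typedefs use the usual clang ext_vector_type pattern and, like the alignment requirement, are assumptions of this sketch rather than the repository's exact definitions.

typedef _Float16 half_t;
typedef half_t half4_t __attribute__((ext_vector_type(4)));

// View a scalar fp16 buffer as packed 4-wide vectors and reduce it, mirroring the
// reinterpret_cast of p_a_block_* / p_b_block_* above. The caller must guarantee the
// buffer holds at least 4 * n_vec elements and is aligned for 8-byte vector loads.
inline float sum_as_half4(const half_t* p_block, int n_vec)
{
    auto p_vec = reinterpret_cast<const half4_t*>(p_block);
    float sum  = 0.0f;
    for(int i = 0; i < n_vec; ++i)
        for(int j = 0; j < 4; ++j)
            sum += static_cast<float>(p_vec[i][j]);
    return sum;
}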
@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
}
}
};
}
#endif
@@ -810,7 +810,7 @@ struct XdlopsGemm_t
(mfma_type.group_size * mfma_type.num_input_blks);
index_t bindex = blk_td;
p_c_thread[m + c_off] += inner_product_with_conversion<FloatC>{}(
- p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
+ p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
}
}
}
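The accumulation in the hunk above converts both operands to the accumulation type FloatC before multiplying into p_c_thread. A scalar-only sketch of such a functor is shown below; the repository's inner_product_with_conversion also provides packed-vector overloads, which this sketch deliberately omits.

// Convert both inputs to the accumulation type, then multiply.
// Scalar-only illustration; not the repository's full functor.
template <class FloatC>
struct inner_product_with_conversion_sketch
{
    template <class FloatA, class FloatB>
    FloatC operator()(FloatA a, FloatB b) const
    {
        return static_cast<FloatC>(a) * static_cast<FloatC>(b);
    }
};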
@@ -13,16 +13,16 @@ template <class T,
class InLeftPads,
class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
- const Tensor<T>& in_nchw,
- WeiDesc,
- const Tensor<T>& wei_kcyx,
- OutDesc,
- Tensor<T>& out_nkhw,
- ConvStrides,
- ConvDilations,
- InLeftPads,
- InRightPads,
- ck::index_t nrepeat)
+ const Tensor<T>& in_nchw,
+ WeiDesc,
+ const Tensor<T>& wei_kcyx,
+ OutDesc,
+ Tensor<T>& out_nkhw,
+ ConvStrides,
+ ConvDilations,
+ InLeftPads,
+ InRightPads,
+ ck::index_t nrepeat)
{
using namespace ck;
@@ -60,7 +60,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmKPerBlock = 4;
constexpr index_t GemmKPACK = 4;
constexpr index_t GemmMPerWave = 64;
constexpr index_t GemmNPerWave = 64;
@@ -76,7 +76,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK = Sequence<1, 2, 4>;
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK = Sequence<4, 32, 1>;
- constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
+ constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPACK = 1;
constexpr index_t GemmM = K;
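GemmM = K above is the M side of the implicit-GEMM view; by the same view (see the PassThrough/Merge transform earlier in this commit), GemmN is the merged N * Ho * Wo extent. A hedged sketch of how the launch grid is typically sized for this kernel family follows; the helper name is illustrative and the exact expression in the elided lines of this file may differ.

// One workgroup per (GemmMPerBlock x GemmNPerBlock) tile of the GemmM x GemmN problem;
// assumes GemmM and GemmN are exact multiples of the per-block tile sizes.
constexpr int calc_grid_size(int GemmM, int GemmN, int GemmMPerBlock, int GemmNPerBlock)
{
    return (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock);
}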
@@ -87,51 +87,52 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
- constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
- GridSize,
- BlockSize,
- half,
- float,
- decltype(in_nchw_desc),
- decltype(wei_kcyx_desc),
- decltype(out_nkhw_desc),
- ConvStrides,
- ConvDilations,
- InLeftPads,
- InRightPads,
- GemmMPerBlock,
- GemmNPerBlock,
- GemmKPerBlock,
- GemmKPACK,
- GemmMPerWave,
- GemmNPerWave,
- ThreadGemmDataPerReadM,
- ThreadGemmDataPerReadN,
- GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
- GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
- GemmABlockCopySrcDataPerRead_GemmKPACK,
- GemmABlockCopyDstDataPerWrite_GemmKPACK,
- GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
- GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
- GemmBBlockCopySrcDataPerRead_GemmN,
- GemmBBlockCopyDstDataPerWrite_GemmKPACK>{};
+ constexpr auto gridwise_conv =
+ GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw<
+ GridSize,
+ BlockSize,
+ half,
+ float,
+ decltype(in_nchw_desc),
+ decltype(wei_kcyx_desc),
+ decltype(out_nkhw_desc),
+ ConvStrides,
+ ConvDilations,
+ InLeftPads,
+ InRightPads,
+ GemmMPerBlock,
+ GemmNPerBlock,
+ GemmKPerBlock,
+ GemmKPACK,
+ GemmMPerWave,
+ GemmNPerWave,
+ ThreadGemmDataPerReadM,
+ ThreadGemmDataPerReadN,
+ GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK,
+ GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK,
+ GemmABlockCopySrcDataPerRead_GemmKPACK,
+ GemmABlockCopyDstDataPerWrite_GemmKPACK,
+ GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK,
+ GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK,
+ GemmBBlockCopySrcDataPerRead_GemmN,
+ GemmBBlockCopyDstDataPerWrite_GemmKPACK>{};
for(index_t i = 0; i < 10; ++i)
{
float time =
- launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
- dim3(GridSize),
- dim3(BlockSize),
- 0,
- 0,
- static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
- static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
- static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
- printf("Elapsed time : %f ms, %f TFlop/s\n",
- time,
- (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
- (std::size_t(1000) * 1000 * 1000) / time);
+ launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+ dim3(GridSize),
+ dim3(BlockSize),
+ 0,
+ 0,
+ static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+ printf("Elapsed time : %f ms, %f TFlop/s\n",
+ time,
+ (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+ (std::size_t(1000) * 1000 * 1000) / time);
}
// warm up
@@ -139,14 +140,14 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
- launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
- dim3(GridSize),
- dim3(BlockSize),
- 0,
- 0,
- static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
- static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
- static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+ launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+ dim3(GridSize),
+ dim3(BlockSize),
+ 0,
+ 0,
+ static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
}
printf("Start running %d times...\n", nrepeat);
@@ -156,25 +157,25 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
- launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
- dim3(GridSize),
- dim3(BlockSize),
- 0,
- 0,
- static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
- static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
- static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
- }
+ launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
+ dim3(GridSize),
+ dim3(BlockSize),
+ 0,
+ 0,
+ static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+ static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+ }
cudaDeviceSynchronize();
auto end = std::chrono::steady_clock::now();
float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;
printf("Average elapsed time : %f ms, %f TFlop/s\n",
ave_time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time);
printf("Average elapsed time : %f ms, %f TFlop/s\n",
ave_time,
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
(std::size_t(1000) * 1000 * 1000) / ave_time);
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
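For reference, the TFlop/s figures printed above follow from the standard forward-convolution flop count: every element of the NKHW output accumulates over C * Y * X filter taps, and each multiply-add counts as two flops. The helpers below are a hedged sketch with illustrative names and argument lists, not the repository's calculate_convolution_flops signature.

#include <cstddef>

// 2 flops (multiply + add) per filter tap per output element.
inline std::size_t conv_fwd_flops(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                                  std::size_t C, std::size_t Y, std::size_t X)
{
    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}

// Matches the printf above: dividing flops by 1e9 and then by elapsed milliseconds
// is the same as flops / (1e12 * elapsed seconds), i.e. TFlop/s.
inline float tflops(std::size_t flops, float elapsed_ms)
{
    return static_cast<float>(flops) / (1000.0f * 1000.0f * 1000.0f) / elapsed_ms;
}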
@@ -618,16 +618,16 @@ int main(int argc, char* argv[])
nrepeat);
#elif 1
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
- in_nchw,
- wei_kcyx_desc,
- wei_kcyx,
- out_nkhw_desc,
- out_nkhw_device,
- ConvStrides{},
- ConvDilations{},
- LeftPads{},
- RightPads{},
- nrepeat);
+ in_nchw,
+ wei_kcyx_desc,
+ wei_kcyx,
+ out_nkhw_desc,
+ out_nkhw_device,
+ ConvStrides{},
+ ConvDilations{},
+ LeftPads{},
+ RightPads{},
+ nrepeat);
#endif
if(do_verification)