Commit f885c131 authored by Chao Liu's avatar Chao Liu
Browse files

tidy

parent 80120f0a
......@@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
return (T*)p;
}
template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
return (T CONSTANT*)p;
}
} // namespace ck
#endif
......@@ -11,59 +11,11 @@ namespace ck {
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, bool>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
#endif
}
template <typename T>
__host__ __device__ void print_array_v2(const char* s, T a)
{
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
#endif
}
} // namespace ck
......
......@@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2];
......
......@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......
......@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......
......@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize),
dim3(BlockSize),
0,
0,
wei_e_k_global_desc,
p_wei_global,
in_e_n_ho_wo_global_desc,
......
......@@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
......@@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
......@@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else
{
......@@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
return ave_time;
......
......@@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
......@@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
......@@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
else
{
......@@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
cast_pointer_to_constant_address_space(
a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
}
return ave_time;
......
......@@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
......@@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time =
launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
0,
p_a_grid,
p_b_grid,
p_c_grid,
(void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
(void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
float ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;
}
......
......@@ -142,10 +142,8 @@ int main(int argc, char* argv[])
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
switch(layout)
if(layout == ConvTensorLayout::NCHW)
{
case ConvTensorLayout::NCHW:
// NCHW
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(C);
in_lengths_host[2] = static_cast<std::size_t>(Hi);
......@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
out_lengths_host[1] = static_cast<std::size_t>(K);
out_lengths_host[2] = static_cast<std::size_t>(Ho);
out_lengths_host[3] = static_cast<std::size_t>(Wo);
break;
case ConvTensorLayout::NHWC:
// NHWC
}
else if(layout == ConvTensorLayout::NHWC)
{
in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(Hi);
in_lengths_host[2] = static_cast<std::size_t>(Wi);
......@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
out_lengths_host[1] = static_cast<std::size_t>(Ho);
out_lengths_host[2] = static_cast<std::size_t>(Wo);
out_lengths_host[3] = static_cast<std::size_t>(K);
break;
default: throw std::runtime_error("wrong! not implemented");
}
else
{
std::runtime_error("wrong! not implemented");
}
Tensor<in_data_t> in(in_lengths_host);
......
......@@ -34,24 +34,16 @@ struct KernelTimer
using device_stream_t = hipStream_t;
template <typename... Args, typename F>
void launch_kernel(F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
hipStream_t stream_id = nullptr;
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
}
template <typename... Args, typename F>
float launch_and_time_kernel(F kernel,
int nrepeat,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
float launch_and_time_kernel(
F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
KernelTimer timer;
......@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,
printf("Warm up\n");
hipStream_t stream_id = nullptr;
// warm up
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment