Commit f885c131 authored by Chao Liu's avatar Chao Liu
Browse files

tidy

parent 80120f0a
...@@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) ...@@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
return (T*)p; return (T*)p;
} }
template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
return (T CONSTANT*)p;
}
} // namespace ck } // namespace ck
#endif #endif
...@@ -11,59 +11,11 @@ namespace ck { ...@@ -11,59 +11,11 @@ namespace ck {
template <typename T> template <typename T>
__host__ __device__ void print_array(const char* s, T a) __host__ __device__ void print_array(const char* s, T a)
{ {
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size(); constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n");
}
else if constexpr(is_same<data_type, bool>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize); printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });
printf("}\n"); printf("}\n");
#endif
}
template <typename T>
__host__ __device__ void print_array_v2(const char* s, T a)
{
using data_type = decltype(a.At(Number<0>{}));
constexpr index_t nsize = a.Size();
#if 0
if constexpr(is_same<data_type, uint32_t>{})
{
printf("%s size %u, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); });
printf("}\n");
}
else if constexpr(is_same<data_type, int32_t>{})
{
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
}
#else
printf("%s size %d, {", s, nsize);
static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });
printf("}\n");
#endif
} }
} // namespace ck } // namespace ck
......
...@@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw ...@@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const auto K = out_n_ho_wo_k_lengths[I3]; const auto K = out_n_ho_wo_k_lengths[I3];
const auto C = wei_k_y_x_c_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3];
const auto Hi = in_n_hi_wi_c_lengths[I1];
const auto Wi = in_n_hi_wi_c_lengths[I2];
const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Ho = out_n_ho_wo_k_lengths[I1];
const auto Wo = out_n_ho_wo_k_lengths[I2]; const auto Wo = out_n_ho_wo_k_lengths[I2];
......
...@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
......
...@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad ...@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
......
...@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
...@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp ...@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3(GridSize), dim3(GridSize),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
wei_e_k_global_desc, wei_e_k_global_desc,
p_wei_global, p_wei_global,
in_e_n_ho_wo_global_desc, in_e_n_ho_wo_global_desc,
......
...@@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
...@@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
...@@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else else
{ {
...@@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, ...@@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
return ave_time; return ave_time;
......
...@@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
...@@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
...@@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
else else
{ {
...@@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, ...@@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(
(void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
} }
return ave_time; return ave_time;
......
...@@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
...@@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, ...@@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = float ave_time = launch_and_time_kernel(
launch_and_time_kernel(kernel, kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
(void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()),
(void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif #endif
return ave_time; return ave_time;
} }
......
...@@ -142,10 +142,8 @@ int main(int argc, char* argv[]) ...@@ -142,10 +142,8 @@ int main(int argc, char* argv[])
std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);
switch(layout) if(layout == ConvTensorLayout::NCHW)
{ {
case ConvTensorLayout::NCHW:
// NCHW
in_lengths_host[0] = static_cast<std::size_t>(N); in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(C); in_lengths_host[1] = static_cast<std::size_t>(C);
in_lengths_host[2] = static_cast<std::size_t>(Hi); in_lengths_host[2] = static_cast<std::size_t>(Hi);
...@@ -158,9 +156,9 @@ int main(int argc, char* argv[]) ...@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
out_lengths_host[1] = static_cast<std::size_t>(K); out_lengths_host[1] = static_cast<std::size_t>(K);
out_lengths_host[2] = static_cast<std::size_t>(Ho); out_lengths_host[2] = static_cast<std::size_t>(Ho);
out_lengths_host[3] = static_cast<std::size_t>(Wo); out_lengths_host[3] = static_cast<std::size_t>(Wo);
break; }
case ConvTensorLayout::NHWC: else if(layout == ConvTensorLayout::NHWC)
// NHWC {
in_lengths_host[0] = static_cast<std::size_t>(N); in_lengths_host[0] = static_cast<std::size_t>(N);
in_lengths_host[1] = static_cast<std::size_t>(Hi); in_lengths_host[1] = static_cast<std::size_t>(Hi);
in_lengths_host[2] = static_cast<std::size_t>(Wi); in_lengths_host[2] = static_cast<std::size_t>(Wi);
...@@ -173,8 +171,10 @@ int main(int argc, char* argv[]) ...@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
out_lengths_host[1] = static_cast<std::size_t>(Ho); out_lengths_host[1] = static_cast<std::size_t>(Ho);
out_lengths_host[2] = static_cast<std::size_t>(Wo); out_lengths_host[2] = static_cast<std::size_t>(Wo);
out_lengths_host[3] = static_cast<std::size_t>(K); out_lengths_host[3] = static_cast<std::size_t>(K);
break; }
default: throw std::runtime_error("wrong! not implemented"); else
{
std::runtime_error("wrong! not implemented");
} }
Tensor<in_data_t> in(in_lengths_host); Tensor<in_data_t> in(in_lengths_host);
......
...@@ -34,24 +34,16 @@ struct KernelTimer ...@@ -34,24 +34,16 @@ struct KernelTimer
using device_stream_t = hipStream_t; using device_stream_t = hipStream_t;
template <typename... Args, typename F> template <typename... Args, typename F>
void launch_kernel(F kernel, void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{ {
hipStream_t stream_id = nullptr;
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
} }
template <typename... Args, typename F> template <typename... Args, typename F>
float launch_and_time_kernel(F kernel, float launch_and_time_kernel(
int nrepeat, F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
hipStream_t stream_id,
Args... args)
{ {
KernelTimer timer; KernelTimer timer;
...@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel, ...@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,
printf("Warm up\n"); printf("Warm up\n");
hipStream_t stream_id = nullptr;
// warm up // warm up
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment