"...composable_kernel_rocm.git" did not exist on "7dcb14d9d495e6329477f5ee27a27cfbb4ce49a6"
Commit f66a71c7 authored by Jing Zhang's avatar Jing Zhang
Browse files

make static

parent 4e5e68a1
...@@ -322,10 +322,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -322,10 +322,10 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2);
const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3);
const auto K0 = K / KPerBlock; const auto K0 = Number<K / KPerBlock>{};
const auto N0 = N / NPerBlock; const auto N0 = Number<N / NPerBlock>{};
const auto H0 = Ho / HoPerBlock; const auto H0 = Number<Ho / HoPerBlock>{};
const auto W0 = Wo / WoPerBlock; const auto W0 = Number<Wo / WoPerBlock>{};
const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( const auto c_blockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
...@@ -353,23 +353,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3 ...@@ -353,23 +353,34 @@ struct GridwiseGemmDlops_km_kn_mn_v3
return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
} }
template <bool HasMainE0BlockLoop> template <bool HasMainE0BlockLoop_>
__device__ static void __device__ static void
Run(const FloatAB* __restrict__ p_a_global, Run(const FloatAB* __restrict__ p_a_global,
const FloatAB* __restrict__ p_b_global, const FloatAB* __restrict__ p_b_global,
const FloatC* __restrict__ p_bias_global, const FloatC* __restrict__ p_bias_global,
FloatC* __restrict__ p_c_global, FloatC* __restrict__ p_c_global,
FloatAB* __restrict__ p_shared_block, FloatAB* __restrict__ p_shared_block,
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc_,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc_,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc_,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor_,
integral_constant<bool, HasMainE0BlockLoop>) integral_constant<bool, HasMainE0BlockLoop_>)
{ {
const auto bias_k0_k1_grid_desc = constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2{};
constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
CBlockIdToBlockClusterAdaptor_K_N_H_W{};
constexpr auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc);
constexpr bool HasMainE0BlockLoop =
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0) > 1;
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
......
...@@ -328,6 +328,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -328,6 +328,11 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
float ave_time = 0; float ave_time = 0;
static_assert(a_e0_e1_k0_k1_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.IsKnownAtCompileTime(), "");
static_assert(c_blockid_to_k_n_h_w_block_cluster_adaptor.IsKnownAtCompileTime(), "");
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
if(has_main_e0_block_loop) if(has_main_e0_block_loop)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment