Commit 0f276ac2 authored by Jing Zhang's avatar Jing Zhang
Browse files

add configurable makeddesc

parent 35a57947
...@@ -18,7 +18,7 @@ template <typename GridwiseGemm, ...@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2, typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
__global__ void __global__ void
...@@ -34,7 +34,7 @@ __global__ void ...@@ -34,7 +34,7 @@ __global__ void
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor) const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{ {
constexpr index_t shared_block_size = constexpr index_t shared_block_size =
...@@ -51,7 +51,7 @@ __global__ void ...@@ -51,7 +51,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
} }
...@@ -65,7 +65,7 @@ template <typename GridwiseGemm, ...@@ -65,7 +65,7 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2, typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2, typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2, typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2, typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W, typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop> bool HasMainE0BlockLoop>
__global__ void __global__ void
...@@ -80,7 +80,7 @@ __global__ void ...@@ -80,7 +80,7 @@ __global__ void
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc, const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, const void CONSTANT* p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor) const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{ {
// first cast void CONSTANT void* to void* // first cast void CONSTANT void* to void*
...@@ -94,9 +94,9 @@ __global__ void ...@@ -94,9 +94,9 @@ __global__ void
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
*reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>( *reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc)); cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
*reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2*>( *reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx*>(
cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc)); cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc));
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor = const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>( *reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor)); cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));
...@@ -115,7 +115,7 @@ __global__ void ...@@ -115,7 +115,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc, a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor, c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{}); integral_constant<bool, HasMainE0BlockLoop>{});
} }
...@@ -129,7 +129,7 @@ template <index_t BlockSize, ...@@ -129,7 +129,7 @@ template <index_t BlockSize,
typename AGridDesc_E0_E1_K_E2, typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2, typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo, typename CGridDesc_K_N_Ho_Wo,
typename DGridDesc_K_N_Hox2_Wox2, typename DGridDesc_K_N_Hx_Wx,
index_t E1_, index_t E1_,
index_t E2_, index_t E2_,
index_t K2_, index_t K2_,
...@@ -162,7 +162,8 @@ template <index_t BlockSize, ...@@ -162,7 +162,8 @@ template <index_t BlockSize,
typename DGlobalStepHacks, typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks, typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks, typename BGlobalMoveSliceWindowStepHacks,
index_t activ_type = 0> index_t activ_type = 0,
index_t add_type = 0>
struct GridwiseGemmDlops_km_kn_mn_v3_add struct GridwiseGemmDlops_km_kn_mn_v3_add
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -327,27 +328,58 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -327,27 +328,58 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
} }
__host__ __device__ static constexpr auto MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor( __host__ __device__ static constexpr auto
const DGridDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_grid_desc) MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{ {
const auto K = d_k_n_hox2_wox2_grid_desc.GetLength(I0); const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hox2_wox2_grid_desc.GetLength(I1); const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hox2 = d_k_n_hox2_wox2_grid_desc.GetLength(I2); const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wox2 = d_k_n_hox2_wox2_grid_desc.GetLength(I3); const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = HoPerThread / 2;
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Hx / (H1 * H2);
const auto W2 = WoPerThread / 2;
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wx / (W1 * W2);
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{}; const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1; const auto K0 = K / K1;
const auto H2 = HoPerThread * 2; const auto H2 = HoPerThread * 2;
const auto H1 = Number<HoPerBlock / HoPerThread>{}; const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Hox2 / (H1 * H2); const auto H0 = Hx / (H1 * H2);
const auto W2 = WoPerThread * 2; const auto W2 = WoPerThread * 2;
const auto W1 = Number<WoPerBlock / WoPerThread>{}; const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wox2 / (W1 * W2); const auto W0 = Wx / (W1 * W2);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = transform_tensor_descriptor( const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hox2_wox2_grid_desc, d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N), make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)), make_unmerge_transform(make_tuple(H0, H1, H2)),
...@@ -355,7 +387,24 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -355,7 +387,24 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc; return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptor(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
if constexpr(add_type == 0)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(d_k_n_hx_wx_grid_desc);
}
else if constexpr(add_type == 1)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
}
else
{
return MakeCK0K1NH0H1H2W0W1W2GridDescriptor(d_k_n_hx_wx_grid_desc);
}
} }
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
...@@ -385,17 +434,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -385,17 +434,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 = using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
decltype(MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(DGridDesc_K_N_Hox2_Wox2{})); decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W = using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc) const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{ {
const auto K0 = d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetLength(I0); const auto K0 = d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetLength(I0);
const auto K1 = d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetLength(I1); const auto K1 = d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetLength(I1);
return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
} }
...@@ -411,7 +460,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -411,7 +460,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor, const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>) integral_constant<bool, HasMainE0BlockLoop>)
{ {
...@@ -419,13 +468,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -419,13 +468,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
// constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{}; // constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
// constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = // constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
// BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{}; // BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
// constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = // constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
// DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2{}; // DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx{};
// constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor = // constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
// CBlockIdToBlockClusterAdaptor_K_N_H_W{}; // CBlockIdToBlockClusterAdaptor_K_N_H_W{};
const auto bias_k0_k1_grid_desc = const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc); MakeBiasK0K1GridDescriptor(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
...@@ -434,7 +483,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -434,7 +483,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetElementSpaceSize()); p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>( auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
...@@ -933,7 +982,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -933,7 +982,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
constexpr auto WoPerThreadx2 = WoPerThread * 2; constexpr auto WoPerThreadx2 = WoPerThread * 2;
#if 1 #if 1
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc = constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1, make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{}, Number<KPerThread>{},
I1, I1,
...@@ -946,16 +995,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -946,16 +995,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
StaticBuffer<AddressSpaceEnum_t::Vgpr, StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatC, FloatC,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc.GetElementSpaceSize(), d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true> true>
d_thread_buf; d_thread_buf;
static_for<0, KPerThread, 1>{}([&](auto k_i) { static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
d_thread_buf( d_thread_buf(Number<d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
Number<d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc.CalculateOffset( make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}]; make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
}); });
...@@ -974,58 +1022,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -974,58 +1022,56 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
I1, I1,
Number<WoPerThread>{})); Number<WoPerThread>{}));
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc = constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = transform_tensor_descriptor(
transform_tensor_descriptor( c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc, make_tuple(
make_tuple(make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(Number<KPerThread>{}), make_pass_through_transform(Number<KPerThread>{}),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_embed_transform(make_tuple(I2, Number<HoPerThread>{}), make_embed_transform(make_tuple(I2, Number<HoPerThread>{}), make_tuple(I0, I1)),
make_tuple(I0, I1)), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_pass_through_transform(I1),
make_pass_through_transform(I1), make_embed_transform(make_tuple(I2, Number<WoPerThread>{}),
make_embed_transform(make_tuple(I2, Number<WoPerThread>{}), make_tuple(I0, I1))),
make_tuple(I0, I1))), make_tuple(Sequence<0>{},
make_tuple(Sequence<0>{}, Sequence<1>{},
Sequence<1>{}, Sequence<2>{},
Sequence<2>{}, Sequence<3>{},
Sequence<3>{}, Sequence<4>{},
Sequence<4>{}, Sequence<5>{},
Sequence<5>{}, Sequence<6>{},
Sequence<6>{}, Sequence<7>{},
Sequence<7>{}, Sequence<8>{}),
Sequence<8>{}), make_tuple(Sequence<0>{},
make_tuple(Sequence<0>{}, Sequence<1>{},
Sequence<1>{}, Sequence<2>{},
Sequence<2>{}, Sequence<3>{},
Sequence<3>{}, Sequence<4>{},
Sequence<4>{}, Sequence<5, 6>{},
Sequence<5, 6>{}, Sequence<7>{},
Sequence<7>{}, Sequence<8>{},
Sequence<8>{}, Sequence<9, 10>{}));
Sequence<9, 10>{}));
#endif #endif
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr auto d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks = constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
DGlobalStepHacks{};
const index_t k_thread_data_on_global = k_thread_id * KPerThread; const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3< ThreadwiseTensorSliceTransfer_v1r3<
FloatC, FloatC,
FloatC, FloatC,
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc), decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc), decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>, Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
1, 1,
true>(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id, make_multi_index(k_block_work_id,
k_thread_data_on_global, k_thread_data_on_global,
n_block_work_id, n_block_work_id,
...@@ -1035,12 +1081,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add ...@@ -1035,12 +1081,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
wo_block_work_id, wo_block_work_id,
wo_thread_id, wo_thread_id,
0)) 0))
.Run(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc, .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf, d_thread_buf,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf, d_global_buf,
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks); d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
} }
} }
}; };
......
...@@ -347,7 +347,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0 ...@@ -347,7 +347,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc); GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(d_k_n_hopx2_wopx2_grid_desc); GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc); using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment