Commit 0f276ac2 authored by Jing Zhang's avatar Jing Zhang
Browse files

add configurable makeddesc

parent 35a57947
......@@ -18,7 +18,7 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
......@@ -34,7 +34,7 @@ __global__ void
const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
......@@ -51,7 +51,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
......@@ -65,7 +65,7 @@ template <typename GridwiseGemm,
typename AGridDesc_E0_E1_K0_K1_E2,
typename BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2,
typename CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2,
typename DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2,
typename DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx,
typename CBlockIdToBlockClusterAdaptor_K_N_H_W,
bool HasMainE0BlockLoop>
__global__ void
......@@ -80,7 +80,7 @@ __global__ void
const void CONSTANT* p_a_e0_e1_k0_k1_e2_grid_desc,
const void CONSTANT* p_b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const void CONSTANT* p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const void CONSTANT* p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const void CONSTANT* p_c_blockid_to_k_n_h_w_block_cluster_adaptor)
{
// first cast void CONSTANT void* to void*
......@@ -94,9 +94,9 @@ __global__ void
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
*reinterpret_cast<const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2*>(
cast_pointer_to_generic_address_space(p_c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc));
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
*reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2*>(
cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc));
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
*reinterpret_cast<const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx*>(
cast_pointer_to_generic_address_space(p_d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc));
const auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
*reinterpret_cast<const CBlockIdToBlockClusterAdaptor_K_N_H_W*>(
cast_pointer_to_generic_address_space(p_c_blockid_to_k_n_h_w_block_cluster_adaptor));
......@@ -115,7 +115,7 @@ __global__ void
a_e0_e1_k0_k1_e2_grid_desc,
b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>{});
}
......@@ -129,7 +129,7 @@ template <index_t BlockSize,
typename AGridDesc_E0_E1_K_E2,
typename BGridDesc_E0_E1_N_Ho_Wo_E2,
typename CGridDesc_K_N_Ho_Wo,
typename DGridDesc_K_N_Hox2_Wox2,
typename DGridDesc_K_N_Hx_Wx,
index_t E1_,
index_t E2_,
index_t K2_,
......@@ -162,7 +162,8 @@ template <index_t BlockSize,
typename DGlobalStepHacks,
typename AGlobalMoveSliceWindowStepHacks,
typename BGlobalMoveSliceWindowStepHacks,
index_t activ_type = 0>
index_t activ_type = 0,
index_t add_type = 0>
struct GridwiseGemmDlops_km_kn_mn_v3_add
{
static constexpr auto I0 = Number<0>{};
......@@ -327,27 +328,58 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc;
}
__host__ __device__ static constexpr auto MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(
const DGridDesc_K_N_Hox2_Wox2& d_k_n_hox2_wox2_grid_desc)
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = HoPerThread / 2;
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Hx / (H1 * H2);
const auto W2 = WoPerThread / 2;
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wx / (W1 * W2);
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
make_unmerge_transform(make_tuple(W0, W1, W2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
const auto K = d_k_n_hox2_wox2_grid_desc.GetLength(I0);
const auto N = d_k_n_hox2_wox2_grid_desc.GetLength(I1);
const auto Hox2 = d_k_n_hox2_wox2_grid_desc.GetLength(I2);
const auto Wox2 = d_k_n_hox2_wox2_grid_desc.GetLength(I3);
const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0);
const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1);
const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2);
const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3);
const auto K1 = Number<KPerBlock>{};
const auto K0 = K / K1;
const auto H2 = HoPerThread * 2;
const auto H1 = Number<HoPerBlock / HoPerThread>{};
const auto H0 = Hox2 / (H1 * H2);
const auto H0 = Hx / (H1 * H2);
const auto W2 = WoPerThread * 2;
const auto W1 = Number<WoPerBlock / WoPerThread>{};
const auto W0 = Wox2 / (W1 * W2);
const auto W0 = Wx / (W1 * W2);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc = transform_tensor_descriptor(
d_k_n_hox2_wox2_grid_desc,
const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor(
d_k_n_hx_wx_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(K0, K1)),
make_pass_through_transform(N),
make_unmerge_transform(make_tuple(H0, H1, H2)),
......@@ -355,7 +387,24 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{}));
return d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc;
return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc;
}
__host__ __device__ static constexpr auto
MakeDK0K1NH0H1HxW0W1WxGridDescriptor(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc)
{
if constexpr(add_type == 0)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(d_k_n_hx_wx_grid_desc);
}
else if constexpr(add_type == 1)
{
return MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(d_k_n_hx_wx_grid_desc);
}
else
{
return MakeCK0K1NH0H1H2W0W1W2GridDescriptor(d_k_n_hx_wx_grid_desc);
}
}
__host__ __device__ static constexpr auto
......@@ -385,17 +434,17 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{}));
using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 =
decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{}));
using DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2 =
decltype(MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(DGridDesc_K_N_Hox2_Wox2{}));
using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx =
decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{}));
using CBlockIdToBlockClusterAdaptor_K_N_H_W =
decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{}));
__host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor(
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc)
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc)
{
const auto K0 = d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetLength(I0);
const auto K1 = d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetLength(I1);
const auto K0 = d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetLength(I0);
const auto K1 = d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetLength(I1);
return make_naive_tensor_descriptor_packed(make_tuple(K0, K1));
}
......@@ -411,7 +460,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc,
const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc,
const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2& d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
const CBlockIdToBlockClusterAdaptor_K_N_H_W& c_blockid_to_k_n_h_w_block_cluster_adaptor,
integral_constant<bool, HasMainE0BlockLoop>)
{
......@@ -419,13 +468,13 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
// constexpr auto a_e0_e1_k0_k1_e2_grid_desc = AGridDesc_E0_E1_K0_K1_E2{};
// constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc =
// BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2{};
// constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
// DGridDesc_K0_K1_N_H0_H1_H2x2_W0_W1_W2x2{};
// constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc =
// DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx{};
// constexpr auto c_blockid_to_k_n_h_w_block_cluster_adaptor =
// CBlockIdToBlockClusterAdaptor_K_N_H_W{};
const auto bias_k0_k1_grid_desc =
MakeBiasK0K1GridDescriptor(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc);
MakeBiasK0K1GridDescriptor(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc);
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize());
......@@ -434,7 +483,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize());
auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_d_global, d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc.GetElementSpaceSize());
p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize());
auto bias_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize());
......@@ -933,7 +982,7 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
constexpr auto WoPerThreadx2 = WoPerThread * 2;
#if 1
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc =
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(I1,
Number<KPerThread>{},
I1,
......@@ -946,15 +995,14 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatC,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc.GetElementSpaceSize(),
d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.GetElementSpaceSize(),
true>
d_thread_buf;
static_for<0, KPerThread, 1>{}([&](auto k_i) {
static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) {
static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) {
d_thread_buf(
Number<d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc.CalculateOffset(
d_thread_buf(Number<d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset(
make_tuple(0, k_i, 0, 0, 0, h_i, 0, 0, w_i))>{}) =
c_thread_buf[Number<c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(
make_tuple(k_i, 0, h_i / 2, w_i / 2))>{}];
......@@ -974,16 +1022,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
I1,
Number<WoPerThread>{}));
constexpr auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc =
transform_tensor_descriptor(
constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = transform_tensor_descriptor(
c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_desc,
make_tuple(make_pass_through_transform(I1),
make_tuple(
make_pass_through_transform(I1),
make_pass_through_transform(Number<KPerThread>{}),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_embed_transform(make_tuple(I2, Number<HoPerThread>{}),
make_tuple(I0, I1)),
make_embed_transform(make_tuple(I2, Number<HoPerThread>{}), make_tuple(I0, I1)),
make_pass_through_transform(I1),
make_pass_through_transform(I1),
make_embed_transform(make_tuple(I2, Number<WoPerThread>{}),
......@@ -1009,23 +1056,22 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
#endif
// hack to control index calculation when iterating over d_k_n_ho_wo_global tensor
constexpr auto d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks =
DGlobalStepHacks{};
constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{};
const index_t k_thread_data_on_global = k_thread_id * KPerThread;
ThreadwiseTensorSliceTransfer_v1r3<
FloatC,
FloatC,
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc),
decltype(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc),
decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc),
Sequence<I1, KPerThread, I1, I1, I1, HoPerThreadx2, I1, I1, WoPerThreadx2>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
CGlobalMemoryDataOperation,
1,
true>(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
make_multi_index(k_block_work_id,
k_thread_data_on_global,
n_block_work_id,
......@@ -1035,12 +1081,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3_add
wo_block_work_id,
wo_thread_id,
0))
.Run(d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_thread_desc,
.Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d_thread_buf,
d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc,
d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc,
d_global_buf,
d_k_n_h0_h1_h2x2_w0_w1_w2x2_global_tensor_step_hacks);
d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks);
}
}
};
......
......@@ -347,7 +347,7 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nc0hwc1_kc0yxc1_nk0
const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc =
GridwiseGemm::MakeCK0K1NH0H1H2W0W1W2GridDescriptor(c_k_n_hop_wop_grid_desc);
const auto d_k0_k1_n_h0_h1_h2x2_w0_w1_w2x2_grid_desc =
GridwiseGemm::MakeDK0K1NH0H1H2x2W0W1W2x2GridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
GridwiseGemm::MakeDK0K1NH0H1HxW0W1WxGridDescriptor(d_k_n_hopx2_wopx2_grid_desc);
using AGridDesc_E0_E1_K0_K1_E2 = decltype(a_e0_e1_k0_k1_e2_grid_desc);
using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment