Commit a52e5a92 authored by ltqin's avatar ltqin
Browse files

finish driver_gemm_xdlops file

parent a3b31a92
......@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
}
template <typename CGMNGridDesc>
template <typename CGMNGridDesc>
__host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc)
{
......
......@@ -246,22 +246,23 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
const auto N0 = N / N1;
#if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, M0, N0))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, M0, N0))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
#elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, N0, M0))),
make_tuple(Sequence<0, 2, 1>{}),
make_tuple(Sequence<0>{}));
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, N0, M0))),
make_tuple(Sequence<0, 2, 1>{}),
make_tuple(Sequence<0>{}));
#endif
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
using CM0N0M1N1M2M3M4N2GridDesc =
decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
......
......@@ -70,46 +70,46 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
using GridwiseGemm =
GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1<BlockSize,
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGK0MK1GridDesc,
BGK0NK1GridDesc,
CGMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_G_K0_M_K1,
ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_G_K0_N_K1,
BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat>;
FloatAB,
FloatAcc,
FloatC,
CGlobalMemoryDataOperation,
AGK0MK1GridDesc,
BGK0NK1GridDesc,
CGMNGridDesc,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
K1,
MRepeat,
NRepeat,
ABlockTransferThreadSliceLengths_G_K0_M_K1,
ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_G_K0_N_K1,
BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector,
AGridStepHacks,
BGridStepHacks,
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat>;
{
std::cout << "a_g_k0_m_k1_grid_desc{" << a_g_k0_m_k1_grid_desc.GetLength(I0) << ", "
......@@ -134,66 +134,65 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
}
const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGK0MK1GridDesc>,
remove_reference_t<BGK0NK1GridDesc>,
remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
/* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;*/
return 0.0;
const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm,
FloatAB,
FloatC,
remove_reference_t<AGK0MK1GridDesc>,
remove_reference_t<BGK0NK1GridDesc>,
remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_g_k0_m_k1_grid_desc_dev_buf(sizeof(AGK0MK1GridDesc));
DeviceMem b_g_k0_n_k1_grid_desc_dev_buf(sizeof(BGK0NK1GridDesc));
DeviceMem c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CGM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_g_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_g_k0_m_k1_grid_desc);
b_g_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_g_k0_n_k1_grid_desc);
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
kernel,
nrepeat,
dim3(grid_size),
dim3(BlockSize),
0,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_g_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_g_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment