Commit a52e5a92 authored by ltqin
Browse files

finish driver_gemm_xdlops file

parent a3b31a92
...@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 ...@@ -158,7 +158,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc);
} }
template <typename CGMNGridDesc> template <typename CGMNGridDesc>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc) MakeCGM0N0M1N1M2M3M4N2GridDescriptor(const CGMNGridDesc& c_g_m_n_grid_desc)
{ {
......
...@@ -246,22 +246,23 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1 ...@@ -246,22 +246,23 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
const auto N0 = N / N1; const auto N0 = N / N1;
#if 1 #if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, M0, N0))), make_tuple(make_merge_transform(make_tuple(G, M0, N0))),
make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
#elif 1 #elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor = const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, N0, M0))), make_tuple(make_merge_transform(make_tuple(G, N0, M0))),
make_tuple(Sequence<0, 2, 1>{}), make_tuple(Sequence<0, 2, 1>{}),
make_tuple(Sequence<0>{})); make_tuple(Sequence<0>{}));
#endif #endif
return c_blockid_to_m0_n0_block_cluster_adaptor; return c_blockid_to_m0_n0_block_cluster_adaptor;
} }
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{})); using CM0N0M1N1M2M3M4N2GridDesc =
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{})); decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
......
...@@ -70,46 +70,46 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -70,46 +70,46 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
using GridwiseGemm = using GridwiseGemm =
GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1<BlockSize, GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1<BlockSize,
FloatAB, FloatAB,
FloatAcc, FloatAcc,
FloatC, FloatC,
CGlobalMemoryDataOperation, CGlobalMemoryDataOperation,
AGK0MK1GridDesc, AGK0MK1GridDesc,
BGK0NK1GridDesc, BGK0NK1GridDesc,
CGMNGridDesc, CGMNGridDesc,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
K1, K1,
MRepeat, MRepeat,
NRepeat, NRepeat,
ABlockTransferThreadSliceLengths_G_K0_M_K1, ABlockTransferThreadSliceLengths_G_K0_M_K1,
ABlockTransferThreadClusterLengths_G_K0_M_K1, ABlockTransferThreadClusterLengths_G_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1, ABlockTransferDstScalarPerVector_K1,
AThreadTransferSrcResetCoordinateAfterRun, AThreadTransferSrcResetCoordinateAfterRun,
BBlockTransferThreadSliceLengths_G_K0_N_K1, BBlockTransferThreadSliceLengths_G_K0_N_K1,
BBlockTransferThreadClusterLengths_G_K0_N_K1, BBlockTransferThreadClusterLengths_G_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1, BBlockTransferDstScalarPerVector_K1,
BThreadTransferSrcResetCoordinateAfterRun, BThreadTransferSrcResetCoordinateAfterRun,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
AGridStepHacks, AGridStepHacks,
BGridStepHacks, BGridStepHacks,
CGridStepHacks, CGridStepHacks,
AGridMoveSliceWindowStepHacks, AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks, BGridMoveSliceWindowStepHacks,
CAccessOrderMRepeatNRepeat>; CAccessOrderMRepeatNRepeat>;
{ {
std::cout << "a_g_k0_m_k1_grid_desc{" << a_g_k0_m_k1_grid_desc.GetLength(I0) << ", " std::cout << "a_g_k0_m_k1_grid_desc{" << a_g_k0_m_k1_grid_desc.GetLength(I0) << ", "
...@@ -134,66 +134,65 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid, ...@@ -134,66 +134,65 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
} }
const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = const auto c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc); GridwiseGemm::MakeCGM0N0M1N1M2M3M4N2GridDescriptor(c_g_m_n_grid_desc);
using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); using CGM0N0M1N1M2M3M4N2GridDesc = decltype(c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_g_m_n_grid_desc);
using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc); const index_t grid_size = GridwiseGemm::CalculateGridSize(c_g_m_n_grid_desc);
const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm, const auto kernel = kernel_gemm_xdlops_v3r1<GridwiseGemm,
FloatAB, FloatAB,
FloatC, FloatC,
remove_reference_t<AGK0MK1GridDesc>, remove_reference_t<AGK0MK1GridDesc>,
remove_reference_t<BGK0NK1GridDesc>, remove_reference_t<BGK0NK1GridDesc>,
remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>, remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>; remove_reference_t<CBlockClusterAdaptor>>;
/* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel, float ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
a_k0_m_k1_grid_desc, a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc, b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor); c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem a_g_k0_m_k1_grid_desc_dev_buf(sizeof(AGK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); DeviceMem b_g_k0_n_k1_grid_desc_dev_buf(sizeof(BGK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CGM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); a_g_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_g_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); b_g_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_g_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel( float ave_time = launch_and_time_kernel(
kernel, kernel,
nrepeat, nrepeat,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
p_a_grid, p_a_grid,
p_b_grid, p_b_grid,
p_c_grid, p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(a_g_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_g_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space( cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif #endif
return ave_time;*/ return ave_time;
return 0.0;
} }
#endif #endif
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment