Commit a52e5a92 authored by ltqin's avatar ltqin
Browse files

finish driver_gemm_xdlops file

parent a3b31a92
......@@ -246,13 +246,13 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
const auto N0 = N / N1;
#if 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, M0, N0))),
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, M0, N0))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
#elif 1
const auto c_blockid_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(G, N0, M0))),
const auto c_blockid_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(G, N0, M0))),
make_tuple(Sequence<0, 2, 1>{}),
make_tuple(Sequence<0>{}));
#endif
......@@ -260,7 +260,8 @@ struct GridwiseGemm_gk0mk1_gk0nk1_gmn_xdlops_v3r1
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CM0N0M1N1M2M3M4N2GridDesc =
decltype(MakeCGM0N0M1N1M2M3M4N2GridDescriptor(CGMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGMNGridDesc{}));
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
......
......@@ -153,7 +153,7 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
remove_reference_t<CGM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
/* #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(grid_size),
......@@ -167,15 +167,15 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_g_k0_m_k1_grid_desc_dev_buf(sizeof(AGK0MK1GridDesc));
DeviceMem b_g_k0_n_k1_grid_desc_dev_buf(sizeof(BGK0NK1GridDesc));
DeviceMem c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CGM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
a_g_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_g_k0_m_k1_grid_desc);
b_g_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_g_k0_n_k1_grid_desc);
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
float ave_time = launch_and_time_kernel(
......@@ -187,13 +187,12 @@ __host__ float driver_gemm_xdlops_v3r1(const FloatAB* p_a_grid,
p_a_grid,
p_b_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(a_g_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_g_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
c_g_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
#endif
return ave_time;*/
return 0.0;
#endif
return ave_time;
}
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment