Commit 1043ab4f authored by ltqin's avatar ltqin
Browse files

trans a and b to gridegemm

parent a088771c
......@@ -18,6 +18,8 @@ template <typename GridwiseGemm,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
......@@ -29,6 +31,8 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const AK0MK1GridDesc a_k0_m_k1_grid_desc,
const BK0NK1GridDesc b_k0_n_k1_grid_desc,
const void CONSTANT* a_b_k0_m_k1_grid_desc,
const void CONSTANT* b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc,
const CBlockClusterAdaptor c_block_cluster_adaptor)
{
......@@ -43,6 +47,8 @@ __global__ void
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
......@@ -52,6 +58,8 @@ template <typename GridwiseGemm,
typename FloatC,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename ABK0MK1GridDesc,
typename BBK0NK1GridDesc,
typename CM0N0M1N1M2M3M4N2GridDesc,
typename CBlockClusterAdaptor>
__global__ void
......@@ -63,6 +71,8 @@ __global__ void
FloatC* __restrict__ p_c_grid,
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_a_b_k0_m_k1_grid_desc,
const void CONSTANT* p_b_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
......@@ -73,6 +83,10 @@ __global__ void
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
const auto b_k0_n_k1_grid_desc = *reinterpret_cast<const BK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc));
const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast<const ABK0MK1GridDesc*>(
cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc));
const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast<const BBK0NK1GridDesc*>(
cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc));
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
*reinterpret_cast<const CM0N0M1N1M2M3M4N2GridDesc*>(
cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc));
......@@ -87,6 +101,8 @@ __global__ void
p_shared_block,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
}
......@@ -311,6 +327,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
return c_blockid_to_m0_n0_block_cluster_adaptor;
}
using ABK0MK1GridDesc = decltype(MakeABK0MK1GridDescriptor(AK0MK1GridDesc{}, I1));
using BBK0NK1GridDesc = decltype(MakeBBK0NK1GridDescriptor(BK0NK1GridDesc{}, I1));
using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{}));
using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
......@@ -320,6 +338,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
FloatAB* __restrict__ p_shared_block,
const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
const CBlockClusterAdaptor& c_block_cluster_adaptor)
{
......@@ -330,12 +350,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r4
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize());
const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
if(get_block_1d_id() == 0)
printf("*****kbatch : %d", kbatch);
const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
const auto kbatch = CalculateKBatch(CMNGridDesc{}, b_k0_n_k1_grid_desc);
if(get_block_1d_id() == 0)
printf("*****kbatch : %d, %d, %d, %d\n",
kbatch,
a_b_k0_m_k1_grid_desc.GetLength(I0),
b_b_k0_n_k1_grid_desc.GetLength(I0),
K0);
// divide block work by [M, N]
const auto block_work_idx =
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
......@@ -123,10 +123,10 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
<< c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
}
const auto kbatch = GridwiseGemm::CalculateKBatch(c_m_n_grid_desc, b_k0_n_k1_grid_desc);
// const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
// const auto b_b_k0_n_k1_grid_desc =
GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, kbatch);
const auto a_b_k0_m_k1_grid_desc =
GridwiseGemm::MakeABK0MK1GridDescriptor(a_k0_m_k1_grid_desc, kbatch);
const auto b_b_k0_n_k1_grid_desc =
GridwiseGemm::MakeBBK0NK1GridDescriptor(b_k0_n_k1_grid_desc, kbatch);
{
std::cout << "k batch number is: " << kbatch << std::endl;
}
......@@ -139,8 +139,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
// using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
// using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using ABK0MK1GridDesc = decltype(a_b_k0_m_k1_grid_desc);
using BBK0NK1GridDesc = decltype(b_b_k0_n_k1_grid_desc);
using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);
......@@ -158,6 +158,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
FloatC,
remove_reference_t<AK0MK1GridDesc>,
remove_reference_t<BK0NK1GridDesc>,
remove_reference_t<ABK0MK1GridDesc>,
remove_reference_t<BBK0NK1GridDesc>,
remove_reference_t<CM0N0M1N1M2M3M4N2GridDesc>,
remove_reference_t<CBlockClusterAdaptor>>;
......@@ -172,12 +174,16 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_c_grid,
a_k0_m_k1_grid_desc,
b_k0_n_k1_grid_desc,
a_b_k0_m_k1_grid_desc,
b_b_k0_n_k1_grid_desc,
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
c_block_cluster_adaptor);
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
......@@ -197,6 +203,8 @@ __host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid,
p_c_grid,
cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment