Commit ac0114f8 authored by turneram's avatar turneram
Browse files

Hard-code kernel params

parent 393caa33
...@@ -109,37 +109,70 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) ...@@ -109,37 +109,70 @@ __global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p)
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
auto gemm = DeviceGemmInstance{}; using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
auto invoker = gemm.MakeInvoker(); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_p), using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
static_cast<BDataType*>(b_p),
static_cast<CDataType*>(c_p), // GridwiseGemm
M, using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
N, ADataType, // TODO: distinguish A/B datatype
K, AccDataType,
StrideA, CShuffleDataType,
StrideB, CDataType,
StrideC, AElementOp,
a_element_op, BElementOp,
b_element_op, CElementOp,
c_element_op); ck::InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
// make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) { BGridDesc_BK0_N_BK1,
// auto gemm = DeviceGemmInstance{}; CGridDesc_M_N,
// auto invoker = gemm.MakeInvoker(); NumGemmKPrefetchStage,
// auto argument = gemm.MakeArgument(static_cast<ADataType*>(a), BlockSize,
// static_cast<BDataType*>(b), MPerBlock,
// static_cast<CDataType*>(c), NPerBlock,
// M, KPerBlock,
// N, AK1,
// K, BK1,
// StrideA, MPerXDL,
// StrideB, NPerXDL,
// StrideC, MXdlPerWave,
// a_element_op, NXdlPerWave,
// b_element_op, ABlockTransferThreadClusterLengths_AK0_M_AK1,
// c_element_op); ABlockTransferThreadClusterArrangeOrder,
// }); ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
LoopSched>;
const auto kernel = kernel_gemm_xdlops_v2r3<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementOp,
BElementOp,
CElementOp,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
true>;
kernel<<<1, 1, 1, 0>>>(p_a, p_b, p_c);
} }
} }
...@@ -162,6 +195,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -162,6 +195,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = "ck_gemm_kernel"; options.kernel_name = "ck_gemm_kernel";
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
return compile_hip_code_object(ck_gemm_kernel, options); return compile_hip_code_object(ck_gemm_kernel, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment