Commit c34561b3 authored by Alan Turner's avatar Alan Turner
Browse files

Use new interface

parent 3ec069ec
......@@ -328,7 +328,6 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static_cast<ck::index_t>(n),
static_cast<ck::index_t>(k),
static_cast<ck::index_t>(numDTensors),
static_cast<ck::index_t>(tuning_value),
transA,
transB,
transCDE,
......@@ -339,9 +338,11 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ck_passthrough,
cde_op,
cde_layout};
const auto solution = problem.GetSolution();
auto blocks_per_batch = problem.GetGridSize();
auto block_size = problem.GetBlockSize();
const auto solutions = problem.GetSolutions();
const auto solution = solutions.at(tuning_value);
const auto template_str = solution.GetStr();
const auto blocks_per_batch = solution.GetGridSize();
const auto block_size = solution.GetBlockSize();
hip_compile_options options;
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
......@@ -363,7 +364,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_kernel,
{{"solution", solution},
{{"solution", template_str},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
......
......@@ -49,7 +49,7 @@ template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
{
constexpr auto desc = G::make_descriptor(to_ck_tensor<A>(),
to_ck_tensor<B>(),
to_ck_tensor<ck_transposeb<B>>(),
ck::make_tuple(to_ck_tensor<Ds>()...),
to_ck_tensor<E>());
G::Run(desc,
......@@ -67,4 +67,4 @@ __device__ void ck_gemm(Ts... xs)
}
} // namespace migraphx
#endif
\ No newline at end of file
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment