Commit 3ec069ec authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent 6a825932
...@@ -38,7 +38,6 @@ ...@@ -38,7 +38,6 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
...@@ -64,7 +63,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING); ...@@ -64,7 +63,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
// NOLINTNEXTLINE // NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp> #include <args.hpp>
...@@ -291,9 +289,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -291,9 +289,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs.back(); auto c_shape = inputs.back();
auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape}); auto tuning_value = get_tuning_for({a_shape, b_shape, c_shape});
auto rank = a_shape.lens().size(); auto rank = a_shape.lens().size();
...@@ -307,15 +305,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -307,15 +305,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto k = a_shape.lens().back(); auto k = a_shape.lens().back();
const auto numDTensors = inputs.size() - 3; const auto numDTensors = inputs.size() - 3;
const bool transA = transposed_matrix(a_shape); const bool transA = transposed_matrix(a_shape);
const bool transB = transposed_matrix(b_shape); const bool transB = transposed_matrix(b_shape);
const bool transCDE = transposed_matrix(c_shape); const bool transCDE = transposed_matrix(c_shape);
const auto a_type = get_type(a_shape); const auto a_type = get_type(a_shape);
const auto b_type = get_type(b_shape); const auto b_type = get_type(b_shape);
const auto cde_type = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);//get_type(c_shape); const auto cde_type =
ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type); // get_type(c_shape);
const auto cde_layout = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout); const auto cde_layout = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
std::string ck_passthrough = "ck_passthrough";//"ck::tensor_operation::element_wise::PassThrough"; std::string ck_passthrough =
"ck_passthrough"; //"ck::tensor_operation::element_wise::PassThrough";
std::string cde_op = ck_passthrough; std::string cde_op = ck_passthrough;
assert(inputs.size() < 4 or v.contains("post")); assert(inputs.size() < 4 or v.contains("post"));
if(v.contains("post")) if(v.contains("post"))
...@@ -323,16 +323,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -323,16 +323,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op = v.at("post").to<std::string>(); cde_op = v.at("post").to<std::string>();
} }
auto problem = ck::tensor_operation::device::instance::Problem{static_cast<ck::index_t>(m), static_cast<ck::index_t>(n), static_cast<ck::index_t>(k), static_cast<ck::index_t>(numDTensors), static_cast<ck::index_t>(tuning_value), auto problem =
transA, transB, transCDE, ck::tensor_operation::device::instance::Problem{static_cast<ck::index_t>(m),
a_type, b_type, cde_type, static_cast<ck::index_t>(n),
ck_passthrough, ck_passthrough, cde_op, cde_layout}; static_cast<ck::index_t>(k),
const auto solution = problem.GetSolution(); static_cast<ck::index_t>(numDTensors),
static_cast<ck::index_t>(tuning_value),
transA,
transB,
transCDE,
a_type,
b_type,
cde_type,
ck_passthrough,
ck_passthrough,
cde_op,
cde_layout};
const auto solution = problem.GetSolution();
auto blocks_per_batch = problem.GetGridSize(); auto blocks_per_batch = problem.GetGridSize();
auto block_size = problem.GetBlockSize(); auto block_size = problem.GetBlockSize();
hip_compile_options options; hip_compile_options options;
auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch; auto grid_size = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
options.output = c_shape; options.output = c_shape;
...@@ -349,7 +361,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -349,7 +361,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{})) if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
options.params += " -DMIGRAPHX_CK_CHECK=1"; options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_kernel, auto src = interpolate_string(ck_gemm_kernel,
{{"solution", solution}, {{"solution", solution},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment