Commit 0b6b490e authored by Alan Turner's avatar Alan Turner
Browse files

Formatting

parent c96139f8
...@@ -76,7 +76,7 @@ MIGRAPHX_REGISTER_OP(ck_gemm); ...@@ -76,7 +76,7 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_softmax_gemm struct ck_gemm_softmax_gemm
{ {
operation op = make_op("dot"); operation op = make_op("dot");
float scale = 1.0; float scale = 1.0;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
...@@ -167,8 +167,8 @@ inline bool can_fold_batch(const std::vector<shape>& inputs) ...@@ -167,8 +167,8 @@ inline bool can_fold_batch(const std::vector<shape>& inputs)
{ {
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) { if(std::any_of(inputs.begin() + 2, inputs.end() - 1, [](auto input) {
return not standard_batch(input); return not standard_batch(input);
})) }))
return false; return false;
const auto& b_strides = b_shape.strides(); const auto& b_strides = b_shape.strides();
return std::all_of( return std::all_of(
......
...@@ -82,7 +82,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -82,7 +82,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& a_shape = inputs[0]; const auto& a_shape = inputs[0];
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.ndim(); auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
......
...@@ -89,7 +89,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -89,7 +89,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& b1_shape = inputs[2]; const auto& b1_shape = inputs[2];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.ndim(); auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
...@@ -173,7 +173,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler> ...@@ -173,7 +173,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{"blocks_per_batch", to_string(blocks_per_batch)}, {"blocks_per_batch", to_string(blocks_per_batch)},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}}); {"kernel", options.kernel_name}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment