Commit 32addf31 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

clang format

parent 8e824ed1
...@@ -179,16 +179,16 @@ struct miopen_apply ...@@ -179,16 +179,16 @@ struct miopen_apply
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
// add additional arguments if need packing. Since lowering is added // add additional arguments if need packing. Since lowering is added
// after auto_contiguous and before eliminate contiguous, the shapes // after auto_contiguous and before eliminate contiguous, the shapes
// of all inputs are standard, so the input shape cannot be transposed. // of all inputs are standard, so the input shape cannot be transposed.
// To avoid that, we need to check whether this argument is an output // To avoid that, we need to check whether this argument is an output
// of contiguous. If true, we should check the shape of the input // of contiguous. If true, we should check the shape of the input
// of the contiguous operator. // of the contiguous operator.
auto prev_ins = refs.at(0); auto prev_ins = refs.at(0);
if (prev_ins->name() == "gpu::contiguous") if(prev_ins->name() == "gpu::contiguous")
{ {
auto input = prev_ins->inputs().front(); auto input = prev_ins->inputs().front();
if (input->get_shape().transposed()) if(input->get_shape().transposed())
{ {
auto pack_a = insert_allocation(input, input->get_shape()); auto pack_a = insert_allocation(input, input->get_shape());
// replace one of the inputs of quant_gemm from the output to the // replace one of the inputs of quant_gemm from the output to the
......
...@@ -75,7 +75,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -75,7 +75,7 @@ argument miopen_quant_gemm::compute(context& ctx,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
// handling the packing of B MUST be before handling that for A // handling the packing of B MUST be before handling that for A
auto arg_res = args.back(); auto arg_res = args.back();
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment