"vscode:/vscode.git/clone" did not exist on "0bd6b842b96d052e03b4726ad63f8d337550cf1f"
Commit ab0b8b80 authored by Shucai Xiao
Browse files

Fix cppcheck errors.

parent 702eeb45
...@@ -33,7 +33,7 @@ struct miopen_quant_convolution ...@@ -33,7 +33,7 @@ struct miopen_quant_convolution
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs); shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs); void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; } std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
private: private:
shape pack_int8_shape(shape& s); shape pack_int8_shape(shape& s);
......
...@@ -13,8 +13,8 @@ struct context; ...@@ -13,8 +13,8 @@ struct context;
struct miopen_quant_gemm struct miopen_quant_gemm
{ {
op::quant_dot op; op::quant_dot op;
mutable argument arg_a{};
mutable argument arg_b{}; miopen_quant_gemm(op::quant_dot qop) : op(qop) {}
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -30,6 +30,10 @@ struct miopen_quant_gemm ...@@ -30,6 +30,10 @@ struct miopen_quant_gemm
{ {
return shapes.size() - 1; return shapes.size() - 1;
} }
private:
mutable argument arg_a;
mutable argument arg_b;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -57,6 +57,25 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const ...@@ -57,6 +57,25 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
std::vector<shape> input_shapes(inputs); std::vector<shape> input_shapes(inputs);
input_shapes.pop_back(); input_shapes.pop_back();
check_shapes{input_shapes}.not_broadcasted(); check_shapes{input_shapes}.not_broadcasted();
bool transa = inputs[0].transposed();
bool transb = inputs[1].transposed();
if(!transb)
{
if(arg_b.empty())
{
arg_b = allocate_gpu(inputs[1]);
}
}
if(transa)
{
if(arg_a.empty())
{
arg_a = allocate_gpu(inputs[0]);
}
}
return op.compute_shape(input_shapes); return op.compute_shape(input_shapes);
} }
...@@ -75,10 +94,6 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -75,10 +94,6 @@ argument miopen_quant_gemm::compute(context& ctx,
if(!transb) if(!transb)
{ {
if(arg_b.empty())
{
arg_b = allocate_gpu(args[1].get_shape());
}
device::pack_a(ctx.get_stream().get(), arg_b, args[1]); device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
} }
...@@ -86,10 +101,6 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -86,10 +101,6 @@ argument miopen_quant_gemm::compute(context& ctx,
// comment of the API // comment of the API
if(transa) if(transa)
{ {
if(arg_a.empty())
{
arg_a = allocate_gpu(args.at(0).get_shape());
}
device::pack_b(ctx.get_stream().get(), arg_a, args[0]); device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment