"include/ck/utility/container_helper.hpp" did not exist on "fcbb978828b308d8c367a3eeaebee485a61b548c"
Commit ef9d113a authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Rename gemm to dot...

parent ceaf5ee0
...@@ -582,7 +582,7 @@ struct reshape ...@@ -582,7 +582,7 @@ struct reshape
} }
}; };
struct gemm struct dot
{ {
float alpha = 1.0; float alpha = 1.0;
float beta = 0.0; float beta = 0.0;
...@@ -593,7 +593,7 @@ struct gemm ...@@ -593,7 +593,7 @@ struct gemm
return pack(f(self.alpha, "alpha"), f(self.beta, "beta")); return pack(f(self.alpha, "alpha"), f(self.beta, "beta"));
} }
std::string name() const { return "gemm"; } std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(2).same_type(); check_shapes{inputs, *this}.has(2).same_type();
......
...@@ -50,7 +50,7 @@ struct onnx_parser ...@@ -50,7 +50,7 @@ struct onnx_parser
{ {
add_generic_op("Add", op::add{}); add_generic_op("Add", op::add{});
add_generic_op("Div", op::div{}); add_generic_op("Div", op::div{});
add_generic_op("MatMul", op::gemm{}); add_generic_op("MatMul", op::dot{});
add_generic_op("Mul", op::mul{}); add_generic_op("Mul", op::mul{});
add_generic_op("Relu", op::activation{"relu"}); add_generic_op("Relu", op::activation{"relu"});
add_generic_op("Sub", op::sub{}); add_generic_op("Sub", op::sub{});
...@@ -274,11 +274,11 @@ struct onnx_parser ...@@ -274,11 +274,11 @@ struct onnx_parser
if(args.size() == 3) if(args.size() == 3)
{ {
uint64_t axis = 1; uint64_t axis = 1;
auto l3 = prog.add_instruction(op::gemm{alpha, beta}, l1, l2); auto l3 = prog.add_instruction(op::dot{alpha, beta}, l1, l2);
auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]); auto l4 = prog.add_instruction(op::broadcast{axis, l3->get_shape()}, args[2]);
return prog.add_instruction(op::add{}, l3, l4); return prog.add_instruction(op::add{}, l3, l4);
} }
return prog.add_instruction(op::gemm{alpha, beta}, l1, l2); return prog.add_instruction(op::dot{alpha, beta}, l1, l2);
} }
instruction_ref instruction_ref
......
...@@ -312,8 +312,8 @@ struct cpu_concat ...@@ -312,8 +312,8 @@ struct cpu_concat
struct cpu_gemm struct cpu_gemm
{ {
op::gemm op; op::dot op;
std::string name() const { return "cpu::gemm"; } std::string name() const { return "cpu::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
...@@ -592,7 +592,7 @@ struct cpu_apply ...@@ -592,7 +592,7 @@ struct cpu_apply
{ {
apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>(); apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>(); apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
apply_map["gemm"] = extend_op<cpu_gemm, op::gemm>(); apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
apply_map["batch_norm_inference"] = apply_map["batch_norm_inference"] =
extend_op<cpu_batch_norm_inference, op::batch_norm_inference>(); extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>(); apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
......
...@@ -22,7 +22,7 @@ namespace gpu { ...@@ -22,7 +22,7 @@ namespace gpu {
struct miopen_gemm struct miopen_gemm
{ {
op::gemm op; op::dot op;
std::string name() const { return "gpu::gemm"; } std::string name() const { return "gpu::gemm"; }
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
......
...@@ -65,7 +65,7 @@ struct miopen_apply ...@@ -65,7 +65,7 @@ struct miopen_apply
{ {
check_shape(s, apply_add(it)); check_shape(s, apply_add(it));
} }
else if(it->name() == "gemm") else if(it->name() == "dot")
{ {
check_shape(s, apply_gemm(it)); check_shape(s, apply_gemm(it));
} }
...@@ -165,7 +165,7 @@ struct miopen_apply ...@@ -165,7 +165,7 @@ struct miopen_apply
instruction_ref apply_gemm(instruction_ref ins) instruction_ref apply_gemm(instruction_ref ins)
{ {
auto&& op = any_cast<op::gemm>(ins->get_operator()); auto&& op = any_cast<op::dot>(ins->get_operator());
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(
ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output); ins, miopen_gemm{op}, ins->inputs().at(0), ins->inputs().at(1), output);
......
...@@ -618,7 +618,7 @@ void gemm_test() ...@@ -618,7 +618,7 @@ void gemm_test()
auto al = p.add_literal(migraph::literal{a_shape, a}); auto al = p.add_literal(migraph::literal{a_shape, a});
migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}}; migraph::shape b_shape{migraph::shape::get_type<T>{}, {5, 3}};
auto bl = p.add_literal(migraph::literal{b_shape, b}); auto bl = p.add_literal(migraph::literal{b_shape, b});
p.add_instruction(migraph::op::gemm{}, al, bl); p.add_instruction(migraph::op::dot{}, al, bl);
p.compile(migraph::cpu::cpu_target{}); p.compile(migraph::cpu::cpu_target{});
auto result = p.eval({}); auto result = p.eval({});
std::vector<T> results_vector(12); std::vector<T> results_vector(12);
......
...@@ -402,7 +402,7 @@ struct test_gemm ...@@ -402,7 +402,7 @@ struct test_gemm
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
p.add_instruction(migraph::op::gemm{}, a, b); p.add_instruction(migraph::op::dot{}, a, b);
return p; return p;
} }
}; };
...@@ -414,7 +414,7 @@ struct test_gemm_ld ...@@ -414,7 +414,7 @@ struct test_gemm_ld
migraph::program p; migraph::program p;
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}, {10, 1}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}, {20, 1}});
p.add_instruction(migraph::op::gemm{}, a, b); p.add_instruction(migraph::op::dot{}, a, b);
return p; return p;
} }
}; };
...@@ -427,7 +427,7 @@ struct test_gemm_transposeb ...@@ -427,7 +427,7 @@ struct test_gemm_transposeb
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {4, 5}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::op::gemm{}, a, bt); p.add_instruction(migraph::op::dot{}, a, bt);
return p; return p;
} }
}; };
...@@ -440,7 +440,7 @@ struct test_gemm_transposea ...@@ -440,7 +440,7 @@ struct test_gemm_transposea
auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}}); auto a = p.add_parameter("a", migraph::shape{migraph::shape::float_type, {5, 4}});
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {5, 3}});
auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a); auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a);
p.add_instruction(migraph::op::gemm{}, at, b); p.add_instruction(migraph::op::dot{}, at, b);
return p; return p;
} }
}; };
...@@ -454,7 +454,7 @@ struct test_gemm_transposeab ...@@ -454,7 +454,7 @@ struct test_gemm_transposeab
auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}}); auto b = p.add_parameter("b", migraph::shape{migraph::shape::float_type, {3, 5}});
auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a); auto at = p.add_instruction(migraph::op::transpose{{1, 0}}, a);
auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b); auto bt = p.add_instruction(migraph::op::transpose{{1, 0}}, b);
p.add_instruction(migraph::op::gemm{}, at, bt); p.add_instruction(migraph::op::dot{}, at, bt);
return p; return p;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment