Commit 1898f87c authored by Shucai Xiao

code backup: pass the int8 gemm pack buffers as extra instruction arguments instead of keeping them as mutable members of miopen_quant_gemm

parent b829300c
@@ -13,8 +13,6 @@ struct context;
 struct miopen_quant_gemm
 {
     op::quant_dot op;
-    mutable argument pack_0{};
-    mutable argument pack_1{};
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -32,18 +30,6 @@ struct miopen_quant_gemm
     }
 };
 
-// struct hip_pack
-// {
-//     std::string name() const { return "gpu::gemm_pack"; }
-//     shape compute_shape(const std::vector<shape>& inputs) const;
-//     argument
-//     compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
-//     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
-//     {
-//         return shapes.size() - 1;
-//     }
-// };
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
...
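With the two `mutable argument` members gone, `miopen_quant_gemm` holds no GPU state of its own: the scratch space for packing now travels as trailing arguments on the instruction, and the dead `hip_pack` stub is dropped along with it. A minimal standalone sketch of that trailing-argument convention, using a hypothetical `buffer` type rather than MIGraphX's real `argument` class:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a device buffer.
struct buffer
{
    std::vector<signed char> data;
};

// Assumed argument layout: {inputs..., pack buffers..., output}.
// The output is always last; pack buffers sit directly before it.
const buffer& output_arg(const std::vector<buffer>& args) { return args.back(); }

const buffer& pack_arg(const std::vector<buffer>& args, std::size_t nth_before_output)
{
    // nth_before_output == 1 -> the buffer just before the output,
    // matching the args.at(args.size() - 2) lookups in compute() below.
    assert(args.size() > nth_before_output);
    return args[args.size() - 1 - nth_before_output];
}
```

Keeping the op stateless also restores the intent of a `const` `compute`: concurrent executions of the same compiled op no longer share, and race on, a single lazily allocated pack buffer.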
@@ -98,7 +98,6 @@ struct miopen_apply
         add_generic_op<hip_min>("min");
         add_extend_op<miopen_gemm, op::dot>("dot");
-        add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
         add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
         add_extend_op<hip_concat, op::concat>("concat");
         add_extend_op<miopen_softmax, op::softmax>("softmax");
@@ -110,6 +109,7 @@ struct miopen_apply
         add_lrn_op();
         add_convolution_op();
         add_quant_convolution_op();
+        add_quant_gemm_op();
         add_pooling_op();
         add_batch_norm_inference_op();
     }
@@ -172,6 +172,31 @@ struct miopen_apply
         });
     }
 
+    void add_quant_gemm_op()
+    {
+        apply_map.emplace("quant_dot", [=](instruction_ref ins) {
+            auto&& op = any_cast<op::quant_dot>(ins->get_operator());
+            std::vector<instruction_ref> refs = ins->inputs();
+            // append extra allocations for operands that need packing
+            if(refs.at(0)->get_shape().transposed())
+            {
+                auto pack_a = insert_allocation(refs.at(0), refs.at(0)->get_shape());
+                refs.push_back(pack_a);
+            }
+            if(!refs.at(1)->get_shape().transposed())
+            {
+                auto pack_b = insert_allocation(refs.at(1), refs.at(1)->get_shape());
+                refs.push_back(pack_b);
+            }
+            auto output = insert_allocation(ins, ins->get_shape());
+            refs.push_back(output);
+            return prog->replace_instruction(ins, miopen_quant_gemm{op}, refs);
+        });
+    }
+
     void add_pooling_op()
     {
         apply_map.emplace("pooling", [=](instruction_ref ins) {
...
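The new `add_quant_gemm_op` replaces the one-line `add_extend_op` registration removed above; it has to build the argument list by hand because the number of pack buffers depends on the operand layouts. A worked example of the resulting indexing, under the assumption of a transposed A, a non-transposed B, and no bias input C:

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical worked example: quant_dot(A, B) with A transposed and B not
// transposed, no bias C. The lowering above produces
//   args = {A, B, pack_a, pack_b, output}  ->  args.size() == 5
int main()
{
    std::size_t n_args       = 5;
    std::size_t pack_arg_num = 0;

    // !transb: B needs packing; its buffer sits just before the output.
    std::size_t idx_b = n_args - 2; // 3 -> pack_b
    ++pack_arg_num;

    // transa: A needs packing; skip the pack buffers claimed so far.
    std::size_t idx_a = n_args - 2 - pack_arg_num; // 2 -> pack_a
    ++pack_arg_num;

    // A, B, C plus the output would leave 4 args after discounting pack
    // buffers; here 5 - 2 == 3, so there is no bias C and beta stays 0.
    bool is_3inputs = (n_args - pack_arg_num == 4);

    std::cout << idx_a << ' ' << idx_b << ' ' << is_3inputs << '\n'; // 2 3 0
}
```

Note the ordering contract: pack_a (if any) is pushed before pack_b, and the output is always pushed last, which is exactly what `compute_shape` and `compute` rely on when they pop or index from the back of the argument vector.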
@@ -56,20 +56,15 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> input_shapes(inputs);
     input_shapes.pop_back();
-    // if(!inputs.at(1).transposed())
-    // {
-    //     if (pack_1.empty())
-    //     {
-    //         pack_1 = allocate_gpu(inputs.at(1));
-    //     }
-    // }
-    // if(inputs.at(0).transposed())
-    // {
-    //     if (pack_0.empty())
-    //     {
-    //         pack_0 = allocate_gpu(inputs.at(0));
-    //     }
-    // }
+    if(!inputs.at(1).transposed())
+    {
+        input_shapes.pop_back();
+    }
+
+    if(inputs.at(0).transposed())
+    {
+        input_shapes.pop_back();
+    }
     check_shapes{input_shapes}.not_broadcasted();
     return op.compute_shape(input_shapes);
@@ -89,37 +84,26 @@ argument miopen_quant_gemm::compute(context& ctx,
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
+    auto arg_b = args.at(1);
+    std::size_t pack_arg_num = 0;
     if(!transb)
     {
-        // use the algorithm to pack A
-        if(pack_1.empty())
-        {
-            std::cout << "allocate pack_1" << std::endl;
-            pack_1 = allocate_gpu(args.at(1).get_shape());
-        }
-        // assert(!pack_1.empty());
-        device::pack_a(ctx.get_stream().get(), pack_1, args[1]);
-        auto pb = from_gpu(pack_1);
-        std::cout << "pb = " << pb << std::endl;
+        arg_b = args.at(args.size() - 2);
+        ++pack_arg_num;
+        device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
     }
     // A needs packing in this scenario; use the pack-B algorithm described
     // in the comment of the API
+    auto arg_a = args.at(0);
     if(transa)
     {
-        if(pack_0.empty())
-        {
-            std::cout << "allocate pack_0" << std::endl;
-            pack_0 = allocate_gpu(args.at(0).get_shape());
-        }
-        device::pack_b(ctx.get_stream().get(), pack_0, args[0]);
-        auto a  = from_gpu(args[0]);
-        auto pa = from_gpu(pack_0);
-        std::cout << "a = " << a << std::endl;
-        std::cout << "pa = " << pa << std::endl;
+        arg_a = args.at(args.size() - 2 - pack_arg_num);
+        ++pack_arg_num;
+        device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
     }
-    bool is_3inputs = (args.size() == 4);
+    bool is_3inputs = (args.size() - pack_arg_num == 4);
     int8_t beta = 0;
     if(is_3inputs)
     {
@@ -153,10 +137,10 @@ argument miopen_quant_gemm::compute(context& ctx,
                      m,
                      k,
                      &alpha_r,
-                     (!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
+                     to_pointer(arg_b),
                      rocblas_datatype_i8_r,
                      ldb,
-                     transa ? to_pointer(pack_0) : to_pointer(args[0]),
+                     to_pointer(arg_a),
                      rocblas_datatype_i8_r,
                      lda,
                      &beta_r,
@@ -183,11 +167,11 @@ argument miopen_quant_gemm::compute(context& ctx,
                      m,
                      k,
                      &alpha_r,
-                     (!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
+                     to_pointer(arg_b),
                      rocblas_datatype_i8_r,
                      ldb,
                      k * n,
-                     transa ? to_pointer(pack_0) : to_pointer(args[0]),
+                     to_pointer(arg_a),
                      rocblas_datatype_i8_r,
                      lda,
                      m * k,
...
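As for why packing happens at all: the int8 path of `rocblas_gemm_ex` consumes operands in a packed layout where the four k-values fed to one dot-product instruction are contiguous, with k a multiple of 4. Since the call above passes `arg_b` as rocBLAS's first matrix operand (the usual trick for row-major data against a column-major BLAS), `pack_a` runs on B when B is not transposed and `pack_b` runs on A when A is. The device kernels are not shown in this diff, so the host-side repack below is only an assumed illustration of such an int8x4 layout, not the actual `device::pack_a` implementation:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumed int8x4 packing for a column-major m x k matrix with leading
// dimension lda: each element's four consecutive k-values land adjacent in
// the output. The layout is an illustration only, not taken from this diff.
std::vector<int8_t>
pack_int8x4(const std::vector<int8_t>& in, int m, int k, int lda)
{
    assert(k % 4 == 0); // int8 gemm historically required k % 4 == 0
    std::vector<int8_t> out(static_cast<std::size_t>(m) * k, 0);
    for(int j = 0; j < k; ++j)
    {
        for(int i = 0; i < m; ++i)
        {
            std::size_t dst = static_cast<std::size_t>(j / 4) * m * 4 +
                              static_cast<std::size_t>(i) * 4 + j % 4;
            out[dst] = in[static_cast<std::size_t>(j) * lda + i];
        }
    }
    return out;
}
```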