Commit 1898f87c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

code backup

parent b829300c
......@@ -13,8 +13,6 @@ struct context;
struct miopen_quant_gemm
{
op::quant_dot op;
mutable argument pack_0{};
mutable argument pack_1{};
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -32,18 +30,6 @@ struct miopen_quant_gemm
}
};
// struct hip_pack
// {
// std::string name() const { return "gpu::gemm_pack"; }
// shape compute_shape(const std::vector<shape>& inputs) const;
// argument
// compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
// std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
// {
// return shapes.size() - 1;
// }
// };
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -98,7 +98,6 @@ struct miopen_apply
add_generic_op<hip_min>("min");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax");
......@@ -110,6 +109,7 @@ struct miopen_apply
add_lrn_op();
add_convolution_op();
add_quant_convolution_op();
add_quant_gemm_op();
add_pooling_op();
add_batch_norm_inference_op();
}
......@@ -172,6 +172,31 @@ struct miopen_apply
});
}
void add_quant_gemm_op()
{
apply_map.emplace("quant_dot", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_dot>(ins->get_operator());
std::vector<instruction_ref> refs = ins->inputs();
// add additional arguments if need packing
if (refs.at(0)->get_shape().transposed())
{
auto pack_a = insert_allocation(refs.at(0), refs.at(0)->get_shape());
refs.push_back(pack_a);
}
if (!refs.at(1)->get_shape().transposed())
{
auto pack_b = insert_allocation(refs.at(1), refs.at(1)->get_shape());
refs.push_back(pack_b);
}
auto output = insert_allocation(ins, ins->get_shape());
refs.push_back(output);
return prog->replace_instruction(ins, miopen_quant_gemm{op}, refs);
});
}
void add_pooling_op()
{
apply_map.emplace("pooling", [=](instruction_ref ins) {
......
......@@ -56,20 +56,15 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> input_shapes(inputs);
input_shapes.pop_back();
// if(!inputs.at(1).transposed())
// {
// if (pack_1.empty())
// {
// pack_1 = allocate_gpu(inputs.at(1));
// }
// }
// if(inputs.at(0).transposed())
// {
// if (pack_0.empty())
// {
// pack_0 = allocate_gpu(inputs.at(0));
// }
// }
if(!inputs.at(1).transposed())
{
input_shapes.pop_back();
}
if(inputs.at(0).transposed())
{
input_shapes.pop_back();
}
check_shapes{input_shapes}.not_broadcasted();
return op.compute_shape(input_shapes);
......@@ -89,37 +84,26 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
auto arg_b = args.at(1);
std::size_t pack_arg_num = 0;
if(!transb)
{
// use the algorithm to pack A
if(pack_1.empty())
{
std::cout << "allocate pack_1" << std::endl;
pack_1 = allocate_gpu(args.at(1).get_shape());
}
// assert(!pack_1.empty());
device::pack_a(ctx.get_stream().get(), pack_1, args[1]);
auto pb = from_gpu(pack_1);
std::cout << "pb = " << pb << std::endl;
arg_b = args.at(args.size() - 2);
++pack_arg_num;
device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
}
// need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API
auto arg_a = args.at(0);
if(transa)
{
if(pack_0.empty())
{
std::cout << "allocate pack_0" << std::endl;
pack_0 = allocate_gpu(args.at(0).get_shape());
}
device::pack_b(ctx.get_stream().get(), pack_0, args[0]);
auto a = from_gpu(args[0]);
auto pa = from_gpu(pack_0);
std::cout << "a = " << a << std::endl;
std::cout << "pa = " << pa << std::endl;
arg_a = args.at(args.size() - 2 - pack_arg_num);
++pack_arg_num;
device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
}
bool is_3inputs = (args.size() == 4);
bool is_3inputs = (args.size() - pack_arg_num == 4);
int8_t beta = 0;
if(is_3inputs)
{
......@@ -153,10 +137,10 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
to_pointer(arg_b),
rocblas_datatype_i8_r,
ldb,
transa ? to_pointer(pack_0) : to_pointer(args[0]),
to_pointer(arg_a),
rocblas_datatype_i8_r,
lda,
&beta_r,
......@@ -183,11 +167,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(pack_1) : to_pointer(args[1]),
to_pointer(arg_b),
rocblas_datatype_i8_r,
ldb,
k * n,
transa ? to_pointer(pack_0) : to_pointer(args[0]),
to_pointer(arg_a),
rocblas_datatype_i8_r,
lda,
m * k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment