Commit d00704ee authored by Khalique's avatar Khalique
Browse files

Merge branch 'reduce-mean-fix' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bert_ops

parents 46cdc2f1 8e967294
...@@ -126,9 +126,6 @@ struct program ...@@ -126,9 +126,6 @@ struct program
friend bool operator==(const program& x, const program& y); friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); } friend bool operator!=(const program& x, const program& y) { return !(x == y); }
std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::make_shared<std::vector<std::pair<float, float>>>();
private: private:
void assign(const program& p); void assign(const program& p);
......
...@@ -17,11 +17,12 @@ void quantize(program& prog); ...@@ -17,11 +17,12 @@ void quantize(program& prog);
// insert the capture operator for the inputs of each operator to be quantized // insert the capture operator for the inputs of each operator to be quantized
// to int8 // to int8
void capture_arguments(program& prog, std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func); const std::function<void(std::size_t, std::vector<argument>)>& func);
void capture_arguments(program& prog, const std::vector<std::string>& ins_names); std::shared_ptr<std::vector<std::pair<float, float>>>
void capture_arguments(program& prog); capture_arguments(program& prog, const std::vector<std::string>& ins_names);
std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog);
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -112,8 +112,7 @@ void program::assign(const program& p) ...@@ -112,8 +112,7 @@ void program::assign(const program& p)
{ {
impl->instructions.clear(); impl->instructions.clear();
} }
impl->ctx = p.impl->ctx; impl->ctx = p.impl->ctx;
int8_quant_params = p.int8_quant_params;
std::unordered_map<instruction_ref, instruction_ref> ins_map; std::unordered_map<instruction_ref, instruction_ref> ins_map;
for(auto ins : iterator_for(p)) for(auto ins : iterator_for(p))
......
...@@ -187,11 +187,6 @@ PYBIND11_MODULE(migraphx, m) ...@@ -187,11 +187,6 @@ PYBIND11_MODULE(migraphx, m)
migraphx::quantize(p, ins_names); migraphx::quantize(p, ins_names);
}); });
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); }); m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
m.def("capture_arguments", [](migraphx::program& p, const std::vector<std::string>& ins_names) {
migraphx::capture_arguments(p, ins_names);
});
m.def("capture_arguments", [](migraphx::program& p) { migraphx::capture_arguments(p); });
#ifdef HAVE_GPU #ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false); m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
...@@ -118,9 +118,9 @@ void quantize(program& prog) { quantize(prog, {"all"}); } ...@@ -118,9 +118,9 @@ void quantize(program& prog) { quantize(prog, {"all"}); }
// For the input of each input argument, we need to insert a // For the input of each input argument, we need to insert a
// capture operator to compute the scale and shift // capture operator to compute the scale and shift
void capture_arguments(program& prog, std::size_t capture_arguments(program& prog,
const std::vector<std::string>& ins_names, const std::vector<std::string>& ins_names,
const std::function<void(std::size_t, std::vector<argument>)>& func) const std::function<void(std::size_t, std::vector<argument>)>& func)
{ {
size_t num_quant_params = 0; size_t num_quant_params = 0;
...@@ -161,34 +161,45 @@ void capture_arguments(program& prog, ...@@ -161,34 +161,45 @@ void capture_arguments(program& prog,
instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args); instruction::replace(ins, ins->get_operator(), ins->get_shape(), new_args);
} }
// set one pair of parameter for each argument return num_quant_params;
prog.int8_quant_params->resize(num_quant_params, std::make_pair(-1.0f, -1.0f));
} }
void capture_arguments(program& prog, const std::vector<std::string>& ins_names) std::shared_ptr<std::vector<std::pair<float, float>>>
capture_arguments(program& prog, const std::vector<std::string>& ins_names)
{ {
auto calc_quant_params = [&](std::size_t ins_index, std::vector<migraphx::argument> args) { std::shared_ptr<std::vector<std::pair<float, float>>> int8_quant_params =
std::pair<float, float> param_pair{1.0f, 0.0f}; std::make_shared<std::vector<std::pair<float, float>>>();
std::shared_ptr<std::vector<float>> max_abs_vals = std::make_shared<std::vector<float>>();
auto calc_quant_params = [int8_quant_params, max_abs_vals](
std::size_t ins_index, std::vector<migraphx::argument> args) {
std::pair<float, float> param_pair{64.0f, 0.0f};
// scale and shift is need for only int8 type, and we do not // scale and shift is need for only int8 type, and we do not
// consider shift, so set shift to 0 // consider shift, so set shift to 0
std::vector<float> vec_val; std::vector<float> vec_val;
args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); }); args.front().visit([&](auto output) { vec_val.assign(output.begin(), output.end()); });
auto max_val = *std::max_element(vec_val.begin(), vec_val.end()); auto max_val = *std::max_element(vec_val.begin(), vec_val.end());
auto min_val = *std::min_element(vec_val.begin(), vec_val.end()); auto min_val = *std::min_element(vec_val.begin(), vec_val.end());
auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val)); auto max_abs = std::max(std::fabs(max_val), std::fabs(min_val));
max_abs_vals->at(ins_index) = std::max(max_abs_vals->at(ins_index), max_abs);
param_pair.first = 127.0f / max_abs; param_pair.first = 127.0f / max_abs_vals->at(ins_index);
(*prog.int8_quant_params)[ins_index] = param_pair; int8_quant_params->at(ins_index) = param_pair;
}; };
capture_arguments(prog, ins_names, calc_quant_params); auto num_params = capture_arguments(prog, ins_names, calc_quant_params);
int8_quant_params->resize(num_params, std::pair<float, float>(64.0f, 0.0f));
max_abs_vals->resize(num_params, 0.0f);
return int8_quant_params;
} }
void capture_arguments(program& prog) std::shared_ptr<std::vector<std::pair<float, float>>> capture_arguments(program& prog)
{ {
std::vector<std::string> ins_names = {"dot", "convolution"}; std::vector<std::string> ins_names = {"dot", "convolution"};
capture_arguments(prog, ins_names); return capture_arguments(prog, ins_names);
} }
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream,
T init, T init,
Input read_input, Input read_input,
Output read_output, Output read_output,
std::size_t relements, std::size_t relements)
std::size_t stride)
{ {
hip_visit_all(result, arg)([&](auto output, auto input) { hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size; const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride; const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ { auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]); return read_input(input.data()[base_idx + j]);
}); });
...@@ -276,25 +275,15 @@ void reduce(hipStream_t stream, ...@@ -276,25 +275,15 @@ void reduce(hipStream_t stream,
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape(); auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(), std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()), std::prev(output_shape.lens().end()),
input_shape.lens().begin())) input_shape.lens().begin()))
{ {
std::size_t stride = std::accumulate(input_shape.strides().begin(), reduce_standard_impl(
input_shape.strides().end(), stream, result, arg, op, init, read_input, read_output, input_shape.lens().back());
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
} }
else else
{ {
......
...@@ -3792,6 +3792,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean> ...@@ -3792,6 +3792,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
}; };
}; };
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 128, 768}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::reduce_mean{{2}}, x);
return p;
};
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int> struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment