Commit b82f53d9 authored by Khalique


Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into conv_same_padding
parents dbb87db1 2c60e428
......@@ -18,6 +18,7 @@ add_library(migraphx
generate.cpp
instruction.cpp
program.cpp
quantization.cpp
shape.cpp
schedule.cpp
pass_manager.cpp
......
......@@ -67,13 +67,6 @@ void eliminate_contiguous::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
// skip the reshape operator for now, since there is a bug
// when a transpose is followed by a reshape
if(ins->name() == "reshape")
{
continue;
}
// Make a copy so we can modify it while we iterate
auto args = ins->inputs();
for(auto arg : ins->inputs())
......
......@@ -103,6 +103,13 @@ struct check_shapes
return *this;
}
const check_shapes& standard_or_scalar() const
{
if(!this->all_of([](const shape& s) { return s.standard() or s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar or in standard layout");
return *this;
}
const check_shapes& packed() const
{
if(!this->all_of([](const shape& s) { return s.packed(); }))
......
......@@ -24,11 +24,12 @@ struct binary : op_name<Derived>
return {s0.type(), s0.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
visit_all(result, args[0], args[1])([&](auto output, auto input1, auto input2) {
if(input1.get_shape().standard() and input2.get_shape().standard())
if(input1.get_shape().packed() and input2.get_shape().packed())
{
std::transform(input1.begin(),
input1.end(),
......@@ -44,6 +45,7 @@ struct binary : op_name<Derived>
});
}
});
return result;
}
};
......
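A note on the standard-to-packed relaxation above: a packed shape occupies contiguous memory even when its strides are not in row-major order, so a flat std::transform over the buffer still visits every element exactly once. A minimal sketch using only the public migraphx::shape API (not part of this commit):

#include <migraphx/shape.hpp>
#include <cassert>

int main()
{
    // column-major layout: packed (6 contiguous elements) but not standard
    migraphx::shape s{migraphx::shape::float_type, {2, 3}, {1, 2}};
    assert(s.packed());
    assert(!s.standard());
    return 0;
}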
#ifndef MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#define MIGRAPHX_GUARD_OPERATORS_CONVERT_HPP
#include <array>
#include <migraphx/op/unary.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct convert : unary<convert>
{
shape::type_t target_type = shape::half_type;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.target_type, "target_type"));
}
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
return {target_type, inputs.at(0).lens(), inputs.at(0).strides()};
}
auto apply() const
{
return [](auto x) { return x; };
}
convert(shape::type_t t) : target_type{t} {}
convert() {}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
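The convert operator changes only the element type: compute_shape carries the input's lens and strides over unchanged, so the layout is preserved. A brief usage sketch (not part of this commit, assuming the public op API):

#include <migraphx/op/convert.hpp>
#include <iostream>

int main()
{
    migraphx::shape fp32{migraphx::shape::float_type, {2, 3}};
    migraphx::op::convert to_half{migraphx::shape::half_type};
    // same lens and strides as the input; only the type changes
    migraphx::shape fp16 = to_half.compute_shape({fp32});
    std::cout << fp16 << std::endl;
    return 0;
}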
......@@ -29,7 +29,7 @@ struct reshape
std::string name() const { return "reshape"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
check_shapes{inputs, *this}.has(1).standard();
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.begin(), dims.end());
auto n_neg_dims = std::count(dims.begin(), dims.end(), -1);
......
......@@ -29,6 +29,7 @@ struct squeeze
std::string name() const { return "squeeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
......
......@@ -23,25 +23,31 @@ struct unary : op_name<Derived>
return {s.type(), s.lens()};
}
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
    argument result{output_shape};
    visit_all(result, args[0])([&](auto output, auto input) {
        if(input.get_shape().standard())
    result.visit([&](auto output) {
        args[0].visit([&](auto input) {
            if(input.get_shape().packed())
            {
                std::transform(input.begin(),
                               input.end(),
                               output.begin(),
                               static_cast<const Derived&>(*this).apply());
            }
            else
            {
                shape_for_each(output.get_shape(), [&](const auto& idx) {
                    output(idx.begin(), idx.end()) =
                        static_cast<const Derived&>(*this).apply()(input(idx.begin(), idx.end()));
                });
            }
        });
    });
    return result;
}
};
......
......@@ -29,6 +29,7 @@ struct unsqueeze
std::string name() const { return "unsqueeze"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard_or_scalar();
auto input_shape = inputs[0];
auto type = input_shape.type();
auto old_lens = input_shape.lens();
......
......@@ -15,6 +15,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/contiguous.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/cosh.hpp>
#include <migraphx/op/cos.hpp>
......
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANTIZATION_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct program;
void quantize(program& prog, const std::vector<std::string>& ins_names);
void quantize(program& prog);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
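A hedged usage sketch for the new interface (the instruction names passed below are illustrative):

#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>

void quantize_selected(migraphx::program& p)
{
    // convert only matching instructions to fp16; the single-argument
    // overload, or the name "all", converts every instruction
    migraphx::quantize(p, {"dot", "convolution"});
}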
......@@ -177,7 +177,7 @@ void memory_coloring_impl::build()
void memory_coloring_impl::rewrite()
{
std::vector<std::size_t> dims;
dims.push_back(required_bytes / sizeof(float));
dims.push_back((required_bytes + sizeof(float) - 1) / sizeof(float));
shape s = {shape::float_type, dims};
instruction_ref scratch_param = p_program->add_parameter("scratch", s);
for(auto ins : iterator_for(*p_program))
......
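The rewrite above replaces a truncating division with a rounded-up one: when required_bytes is not a multiple of sizeof(float), plain division under-allocates the scratch buffer. A self-contained sketch of the idiom (the helper name is hypothetical):

#include <cstddef>
#include <cassert>

// round a byte count up to whole float elements; equivalent to the
// fixed expression in memory_coloring_impl::rewrite
std::size_t scratch_elements(std::size_t required_bytes)
{
    return (required_bytes + sizeof(float) - 1) / sizeof(float);
}

int main()
{
    assert(scratch_elements(10) == 3); // 12 bytes, covers all 10
    assert(10 / sizeof(float) == 2);   // truncating form: only 8 bytes
    return 0;
}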
......@@ -2,6 +2,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/stringutils.hpp>
......@@ -181,6 +182,10 @@ PYBIND11_MODULE(migraphx, m)
});
m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
m.def("quantize", [](migraphx::program& p, std::vector<std::string>& ins_names) {
migraphx::quantize(p, ins_names);
});
m.def("quantize", [](migraphx::program& p) { migraphx::quantize(p, {"all"}); });
#ifdef HAVE_GPU
m.def("allocate_gpu", &migraphx::gpu::allocate_gpu, py::arg("s"), py::arg("host") = false);
......
#include <migraphx/quantization.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
instruction_ref insert_fp16(program& prog,
instruction_ref& ins,
shape::type_t type,
std::unordered_map<instruction_ref, instruction_ref>& map_fp16)
{
if(map_fp16.count(ins) > 0)
{
return map_fp16[ins];
}
assert(ins->get_shape().type() == shape::float_type ||
ins->get_shape().type() == shape::double_type);
instruction_ref ins_fp16{};
ins_fp16 = prog.insert_instruction(std::next(ins), op::convert{type}, ins);
map_fp16[ins] = ins_fp16;
return ins_fp16;
}
void quantize(program& prog, const std::vector<std::string>& ins_names)
{
std::unordered_map<instruction_ref, instruction_ref> map_fp16;
for(auto ins : iterator_for(prog))
{
// "all" indicates that every instruction should be converted
if((not contains(ins_names, "all")) and (not contains(ins_names, ins->name())))
{
continue;
}
shape::type_t orig_type = ins->get_shape().type();
// process all inputs; if an input is fp32 or fp64, convert it
// to fp16 by adding a convert operator
auto inputs = ins->inputs();
std::vector<instruction_ref> converted_inputs;
for(auto input : inputs)
{
auto s = input->get_shape();
if(s.type() == shape::float_type || s.type() == shape::double_type)
{
// if the input is itself a convert operator, use its input
// as the current input instead of inserting another convert
instruction_ref input_fp16{};
if(input->name() == "convert")
{
input_fp16 = input->inputs().front();
}
else
{
input_fp16 = insert_fp16(prog, input, shape::half_type, map_fp16);
}
converted_inputs.push_back(input_fp16);
}
else
{
converted_inputs.push_back(input);
}
}
// no inputs changed, so go to the next instruction
if(inputs == converted_inputs)
{
continue;
}
auto op = ins->get_operator();
auto ins_shape = compute_shape(op, converted_inputs);
if(ins_shape.type() != orig_type)
{
// insert another convert instruction to convert it back
if(ins == std::prev(prog.end()))
{
prog.add_instruction(op::convert{orig_type}, ins);
}
else
{
// check for the dead-code case to avoid an assert
bool output_empty = ins->outputs().empty();
auto ins_orig_type =
prog.insert_instruction(std::next(ins), op::convert{orig_type}, ins);
if(!output_empty)
{
prog.replace_instruction(ins, ins_orig_type);
}
}
}
prog.replace_instruction(ins, op, converted_inputs);
}
}
void quantize(program& prog) { quantize(prog, {"all"}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
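To make the pass concrete, here is a minimal sketch of the transformation it performs (assuming the op::add and parameter APIs; not part of this commit):

#include <migraphx/program.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/op/add.hpp>

int main()
{
    migraphx::program p;
    migraphx::shape s{migraphx::shape::float_type, {4}};
    auto x = p.add_parameter("x", s);
    auto y = p.add_parameter("y", s);
    p.add_instruction(migraphx::op::add{}, x, y);
    // after this call the add runs in half_type, with converts inserted
    // for x and y and a trailing convert back to float_type
    migraphx::quantize(p, {"add"});
    return 0;
}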
......@@ -14,7 +14,9 @@ bool is_reshaper(instruction_ref ins)
// clang-format off
static const std::unordered_set<std::string> names = {
"reshape",
"contiguous"
"contiguous",
"squeeze",
"unsqueeze"
};
// clang-format on
return contains(names, ins->name());
......@@ -45,6 +47,9 @@ void simplify_reshapes::apply(program& p) const
auto end = std::prev(p.end());
for(auto ins : iterator_for(p))
{
if(ins == end and ins->name() == "contiguous")
continue;
// Skip possible dead instructions
if(ins->outputs().empty() and ins != end)
continue;
if(is_reshaper(ins))
......@@ -94,13 +99,6 @@ void simplify_reshapes::apply(program& p) const
p.replace_instruction(ins, t->inputs().front());
}
}
// Replace all reshapes with as_shape
for(auto ins : iterator_for(p))
{
if(ins->name() != "reshape")
continue;
p.replace_instruction(ins, op::as_shape{ins->get_shape()}, ins->inputs());
}
}
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -27,6 +27,7 @@ add_library(migraphx_device
device/add_relu.cpp
device/contiguous.cpp
device/logsoftmax.cpp
device/convert.cpp
device/mul.cpp
device/concat.cpp
device/pad.cpp
......
......@@ -2,7 +2,6 @@
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <algorithm>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void convert(hipStream_t stream, const argument& result, const argument& arg)
{
result.visit([&](auto output) {
arg.visit([&](auto input) {
const auto* input_ptr = device_cast(input.data());
auto* output_ptr = device_cast(output.data());
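// the plain assignment below performs the element-wise conversion,
// since input_ptr and output_ptr have different element types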
gs_launch(stream,
result.get_shape().elements())([=](auto i) { output_ptr[i] = input_ptr[i]; });
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP
#define MIGRAPHX_GUARD_RTGLIB_CONVERT_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_convert : unary_device<hip_convert, device::convert>
{
op::convert op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
hip_convert(op::convert oper) : op(oper) {}
shape compute_shape(std::vector<shape> inputs) const
{
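// the trailing input is the preallocated output buffer; drop it so only
// the real input's shape is checked and forwarded to op::convert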
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_CONVERT_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void convert(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif