Unverified Commit d655ef50 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into scatter-op

parents 2cb895a5 eacf042e
...@@ -147,6 +147,7 @@ jobs: ...@@ -147,6 +147,7 @@ jobs:
os: os:
- ubuntu-16.04 - ubuntu-16.04
- ubuntu-18.04 - ubuntu-18.04
- ubuntu-20.04
configuration: configuration:
- debug - debug
- release - release
......
...@@ -53,6 +53,7 @@ add_library(migraphx ...@@ -53,6 +53,7 @@ add_library(migraphx
remap.cpp remap.cpp
rewrite_batchnorm.cpp rewrite_batchnorm.cpp
rewrite_pooling.cpp rewrite_pooling.cpp
rewrite_quantization.cpp
rewrite_rnn.cpp rewrite_rnn.cpp
schedule.cpp schedule.cpp
serialize.cpp serialize.cpp
...@@ -94,6 +95,7 @@ register_migraphx_ops( ...@@ -94,6 +95,7 @@ register_migraphx_ops(
cosh cosh
cos cos
deconvolution deconvolution
dequantizelinear
div div
dot dot
elu elu
...@@ -132,6 +134,7 @@ register_migraphx_ops( ...@@ -132,6 +134,7 @@ register_migraphx_ops(
prelu prelu
quant_convolution quant_convolution
quant_dot quant_dot
quantizelinear
recip recip
reduce_max reduce_max
reduce_mean reduce_mean
......
...@@ -626,7 +626,8 @@ auto tree(M main_op, Ms... ms) ...@@ -626,7 +626,8 @@ auto tree(M main_op, Ms... ms)
if(idx != leafs.size()) if(idx != leafs.size())
return nullopt; return nullopt;
// Use explicit captures to workaround ICE on gcc // Use explicit captures to workaround ICE on gcc
bool found = sequence_c<sizeof...(Ms)>([&ms..., &ctx, &leafs](auto... is) { // Capture by value to workaround compile error on gcc 9
bool found = sequence_c<sizeof...(Ms)>([ms..., &ctx, &leafs](auto... is) {
return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)(); return fold(lazy_and{})(ctx.lazy_match(ms, leafs[is])...)();
}); });
if(not found) if(not found)
......
#ifndef MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_DEQUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct dequantizelinear
{
    std::string name() const { return "dequantizelinear"; }

    /// Output is always float, with the same dimensions and strides as the input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        return {shape::float_type, inputs[0].lens(), inputs[0].strides()};
    }

    /// y[i] = (x[i] - x_zero_point[i]) * x_scale[i]
    /// args: x, x_scale, and optionally x_zero_point (defaults to all zeros).
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto x       = args.at(0);
        auto x_scale = args.at(1);
        // Default zero point: a zero-filled buffer typed like x. Size the
        // buffer in *bytes* of the zero-point shape (not in elements) so
        // input types wider than one byte do not read past the end of it.
        shape zp_shape{x.get_shape().type(), output_shape.lens()};
        std::vector<int8_t> zeros(zp_shape.bytes(), 0);
        argument x_zero_point{zp_shape, zeros.data()};
        if(args.size() == 3)
        {
            x_zero_point = args.at(2);
        }
        argument result{output_shape};
        visit_all(x, x_zero_point)([&](auto input, auto zero_pts) {
            visit_all(result, x_scale)([&](auto output, auto scales) {
                par_for(output_shape.elements(), [&](auto i) {
                    // Subtract in int64 to avoid unsigned wrap/overflow, then
                    // scale in double before the implicit narrowing to float.
                    output[i] = static_cast<double>(static_cast<int64_t>(input[i]) -
                                                    static_cast<int64_t>(zero_pts[i])) *
                                scales[i];
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#define MIGRAPHX_GUARD_OPERATORS_QUANTIZE_LINEAR_HPP
#include <array>
#include <migraphx/op/common.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/tune_axis.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct quantizelinear
{
    std::string name() const { return "quantizelinear"; }

    /// The result takes the zero point's type when a third input supplies one;
    /// otherwise it is uint8. Dimensions/strides always follow the first input.
    shape compute_shape(std::vector<shape> inputs) const
    {
        if(inputs.size() == 3)
        {
            return {inputs[2].type(), inputs[0].lens(), inputs[0].strides()};
        }
        return {shape::uint8_type, inputs[0].lens(), inputs[0].strides()};
    }

    /// y[i] = saturate(round(x[i] / y_scale[i]) + y_zero_point[i])
    argument compute(const shape& output_shape, std::vector<argument> args) const
    {
        auto data      = args.at(0);
        auto scale_arg = args.at(1);
        // Default zero point: all zeros with the output shape. Without a third
        // argument the output type is the one-byte uint8, so this buffer size
        // matches; with a third argument the buffer is replaced entirely.
        std::vector<int8_t> zero_buffer(output_shape.elements(), 0);
        argument zero_point{output_shape, zero_buffer.data()};
        if(args.size() == 3)
        {
            zero_point = args.at(2);
        }
        argument result{output_shape};
        visit_all(result, zero_point)([&](auto output, auto zero_pts) {
            data.visit([&](auto input) {
                scale_arg.visit([&](auto scales) {
                    // Saturation bounds come from the quantized output type.
                    using quant_type = typename decltype(output)::value_type;
                    auto lo          = static_cast<int64_t>(std::numeric_limits<quant_type>::min());
                    auto hi          = static_cast<int64_t>(std::numeric_limits<quant_type>::max());
                    par_for(output_shape.elements(), [&](auto i) {
                        auto scaled  = static_cast<int64_t>(std::round(input[i] / scales[i]));
                        auto shifted = scaled + static_cast<int64_t>(zero_pts[i]);
                        output[i]    = std::max(lo, std::min(hi, shifted));
                    });
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_QUANTIZATION_HPP
#include <string>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
 * Rewrite quantization ops to equivalent operators.
 *
 * Replaces quantizelinear/dequantizelinear instructions in a module with
 * sequences of elementary operators (div/round/add/clip/convert, and
 * convert/sub/mul respectively) so targets without native quantization
 * kernels can still execute them.
 */
struct rewrite_quantization
{
    // Pass name reported to the pass manager / tracing output.
    std::string name() const { return "rewrite_quantization"; }
    // Rewrites every matching instruction in m in place.
    void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear> ...@@ -15,46 +15,49 @@ struct parse_dequantizelinear : op_parser<parse_dequantizelinear>
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const const std::vector<instruction_ref>& args) const
{ {
int axis = 1; int axis = 1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i(); axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size()); auto n_dim = input_lens.size();
auto sub_zero_point = args[0]; instruction_ref x_scale;
if(args[1]->get_shape().elements() != 1)
{
auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
x_scale = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
}
else
{
x_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
args[1]);
}
if(args.size() == 3) if(args.size() == 3)
{ {
auto zero_point = args[2]; auto x_zero_point = args[2];
if(not(zero_point->get_shape().elements() == 1)) if(x_zero_point->get_shape().elements() != 1)
{ {
axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction( x_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point); make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
x_zero_point);
} }
else
auto zero_point_int32 = info.add_instruction( {
make_op("convert", {{"target_type", shape::int32_type}}), zero_point); x_zero_point = info.add_instruction(
auto sub_zero_point_int32 = info.add_instruction( make_op("multibroadcast", {{"output_lens", input_lens}}), x_zero_point);
make_op("convert", {{"target_type", shape::int32_type}}), sub_zero_point);
sub_zero_point =
info.add_broadcastable_binary_op("sub", sub_zero_point_int32, zero_point_int32);
} }
auto dequant_input = info.add_instruction( return info.add_instruction(
make_op("convert", {{"target_type", shape::float_type}}), sub_zero_point); make_op("dequantizelinear"), args[0], x_scale, x_zero_point);
auto scale = args[1];
if(not(scale->get_shape().elements() == 1))
{
axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale);
} }
return info.add_broadcastable_binary_op("mul", dequant_input, scale);
return info.add_instruction(make_op("dequantizelinear"), args[0], x_scale);
} }
}; };
......
...@@ -42,13 +42,23 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -42,13 +42,23 @@ struct parse_gemm : op_parser<parse_gemm>
// swap the last two elements // swap the last two elements
std::swap(*perm.rbegin(), *(perm.rbegin() + 1)); std::swap(*perm.rbegin(), *(perm.rbegin() + 1));
auto l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[0]) auto l1 = args[0];
: args[0];
if(alpha != 1.0f)
{
auto alpha_literal = info.add_literal(alpha);
auto alpha_l1 = info.add_broadcastable_binary_op("mul", alpha_literal, l1);
l1 = info.add_instruction(make_op("convert", {{"target_type", l1->get_shape().type()}}),
alpha_l1);
}
l1 = (transa) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), l1) : l1;
auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1]) auto l2 = (transb) ? info.add_instruction(make_op("transpose", {{"dims", perm}}), args[1])
: args[1]; : args[1];
if(args.size() == 3) if(args.size() == 3)
{ {
if(beta != 0.f && args[2]->get_shape().elements() > 0) if(beta != 0.0f && args[2]->get_shape().elements() > 0)
{ {
auto out_lens = l1->get_shape().lens(); auto out_lens = l1->get_shape().lens();
out_lens.back() = l2->get_shape().lens().back(); out_lens.back() = l2->get_shape().lens().back();
...@@ -59,12 +69,17 @@ struct parse_gemm : op_parser<parse_gemm> ...@@ -59,12 +69,17 @@ struct parse_gemm : op_parser<parse_gemm>
l3 = info.add_instruction( l3 = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]); make_op("multibroadcast", {{"output_lens", out_lens}}), args[2]);
} }
auto beta_literal = info.add_literal(beta);
auto beta_broadcast = info.add_instruction(
make_op("multibroadcast", {{"output_lens", out_lens}}), beta_literal);
l3 = info.add_instruction(make_op("mul"), l3, beta_broadcast);
return info.add_instruction( return info.add_instruction(
make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2, l3); make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2, l3);
} }
} }
return info.add_instruction(make_op("dot", {{"alpha", alpha}, {"beta", beta}}), l1, l2); return info.add_instruction(make_op("dot", {{"alpha", 1.0f}, {"beta", 1.0f}}), l1, l2);
} }
}; };
......
...@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear> ...@@ -12,83 +12,51 @@ struct parse_quantizelinear : op_parser<parse_quantizelinear>
{ {
std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; } std::vector<op_desc> operators() const { return {{"QuantizeLinear"}}; }
// y = saturate(round(x / y_scale) + zero_point)
instruction_ref parse(const op_desc& opd, instruction_ref parse(const op_desc& opd,
const onnx_parser& /*parser*/, const onnx_parser& /*parser*/,
const onnx_parser::node_info& info, const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const const std::vector<instruction_ref>& args) const
{ {
auto quant_type = shape::uint8_type;
int nargs = args.size();
int max_quant = 255;
int min_quant = 0;
if(nargs == 3)
quant_type = args[2]->get_shape().type();
if(quant_type == shape::int8_type)
{
max_quant = 127;
min_quant = -128;
}
auto max_arg = info.add_literal(max_quant);
auto min_arg = info.add_literal(min_quant);
int axis = 1; int axis = 1;
if(contains(info.attributes, "axis")) if(contains(info.attributes, "axis"))
axis = info.attributes.at("axis").i(); axis = info.attributes.at("axis").i();
auto input_lens = args[0]->get_shape().lens(); auto input_lens = args[0]->get_shape().lens();
int n_dim = static_cast<int>(input_lens.size()); auto n_dim = input_lens.size();
auto scale = args[1]; instruction_ref y_scale;
if(not(scale->get_shape().elements() == 1)) if(args[1]->get_shape().elements() != 1)
{ {
axis = tune_axis(n_dim, axis, opd.op_name); auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
scale = info.add_instruction( y_scale = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), scale); make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}), args[1]);
} }
else
auto div = info.add_broadcastable_binary_op("div", args[0], scale);
auto div_round = info.add_instruction(make_op("round"), div);
auto add_zero_point = div_round;
if(nargs == 3)
{ {
auto zero_point = args[2]; y_scale = info.add_instruction(make_op("multibroadcast", {{"output_lens", input_lens}}),
if(not(zero_point->get_shape().elements() == 1)) args[1]);
{
axis = tune_axis(n_dim, axis, opd.op_name);
zero_point = info.add_instruction(
make_op("broadcast", {{"axis", axis}, {"dims", input_lens}}), zero_point);
}
zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), zero_point);
add_zero_point = info.add_instruction(
make_op("convert", {{"target_type", shape::int32_type}}), add_zero_point);
add_zero_point = info.add_broadcastable_binary_op("add", add_zero_point, zero_point);
} }
auto s = add_zero_point->get_shape(); if(args.size() == 3)
const auto& lens = s.lens(); {
std::vector<int64_t> out_lens(lens.begin(), lens.end()); auto y_zero_point = args[2];
if(min_arg->get_shape() != s) if(y_zero_point->get_shape().elements() != 1)
{ {
min_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), auto tuned_axis = tune_axis(n_dim, axis, opd.op_name);
min_arg); y_zero_point = info.add_instruction(
make_op("broadcast", {{"axis", tuned_axis}, {"dims", input_lens}}),
y_zero_point);
} }
if(max_arg->get_shape() != s) else
{ {
max_arg = info.add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), y_zero_point = info.add_instruction(
max_arg); make_op("multibroadcast", {{"output_lens", input_lens}}), y_zero_point);
}
return info.add_instruction(make_op("quantizelinear"), args[0], y_scale, y_zero_point);
} }
auto saturated = info.add_instruction(make_op("clip"), add_zero_point, min_arg, max_arg); return info.add_instruction(make_op("quantizelinear"), args[0], y_scale);
return info.add_instruction(make_op("convert", {{"target_type", quant_type}}), saturated);
} }
}; };
......
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/program.hpp>
#include <migraphx/shape.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
// Expand a quantizelinear instruction into elementary operators:
//   y = convert(clip(round(x / scale) + zero_point, qmin, qmax), output_type)
void apply_quantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "quantizelinear");
    auto x       = ins->inputs()[0];
    auto y_scale = ins->inputs()[1];
    // Work in the scale's type: convert x to float when the types differ.
    if(x->get_shape().type() != y_scale->get_shape().type())
    {
        x = m.insert_instruction(ins, make_op("convert", {{"target_type", shape::float_type}}), x);
    }
    auto scaled  = m.insert_instruction(ins, make_op("div"), x, y_scale);
    auto rounded = m.insert_instruction(ins, make_op("round"), scaled);
    if(ins->inputs().size() == 3)
    {
        // Shift by the zero point, converted to float to match the rounded value.
        auto zero_point = m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
        rounded = m.insert_instruction(ins, make_op("add"), rounded, zero_point);
    }
    // Saturation bounds come from the quantized output type.
    int64_t max_quant = 0;
    int64_t min_quant = 0;
    ins->get_shape().visit_type([&](auto qt) {
        max_quant = qt.max();
        min_quant = qt.min();
    });
    auto s = rounded->get_shape();
    std::vector<int> min_data(s.elements(), min_quant);
    std::vector<int> max_data(s.elements(), max_quant);
    auto min_arg  = m.add_literal(literal(s, min_data));
    auto max_arg  = m.add_literal(literal(s, max_data));
    auto saturate = m.insert_instruction(ins, make_op("clip"), rounded, min_arg, max_arg);
    m.replace_instruction(
        ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), saturate);
}
// Expand a dequantizelinear instruction into elementary operators:
//   y = (convert(x, float) - convert(zero_point, float)) * scale
void apply_dequantizelinear(module& m, instruction_ref ins)
{
    assert(ins->name() == "dequantizelinear");
    auto x = m.insert_instruction(
        ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[0]);
    auto x_scale = ins->inputs()[1];
    if(ins->inputs().size() == 3)
    {
        // Optional zero point: convert to float and subtract before scaling.
        auto x_zero_point = m.insert_instruction(
            ins, make_op("convert", {{"target_type", shape::float_type}}), ins->inputs()[2]);
        x = m.insert_instruction(ins, make_op("sub"), x, x_zero_point);
    }
    m.replace_instruction(ins, make_op("mul"), x, x_scale);
}
// Walk the module and rewrite each quantization instruction in place.
void rewrite_quantization::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        auto op_name = ins->name();
        if(op_name == "quantizelinear")
            apply_quantizelinear(m, ins);
        else if(op_name == "dequantizelinear")
            apply_dequantizelinear(m, ins);
    }
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <migraphx/remap.hpp> #include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
...@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -46,6 +47,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end()); std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
unsupported_types.erase(shape::type_t::float_type); unsupported_types.erase(shape::type_t::float_type);
return {normalize_ops{}, return {normalize_ops{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
dead_code_elimination{}, dead_code_elimination{},
decompose{}, decompose{},
......
...@@ -123,6 +123,7 @@ add_library(migraphx_gpu ...@@ -123,6 +123,7 @@ add_library(migraphx_gpu
convert.cpp convert.cpp
convolution.cpp convolution.cpp
deconvolution.cpp deconvolution.cpp
device_name.cpp
eliminate_workspace.cpp eliminate_workspace.cpp
elu.cpp elu.cpp
fuse_ops.cpp fuse_ops.cpp
......
...@@ -2,9 +2,9 @@ ...@@ -2,9 +2,9 @@
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx_kernels.hpp> #include <migraphx_kernels.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
...@@ -12,36 +12,6 @@ namespace migraphx { ...@@ -12,36 +12,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(rank<1>{}, props);
}
template <class T> template <class T>
std::string generate_index_ints(const std::vector<T>& v) std::string generate_index_ints(const std::vector<T>& v)
{ {
......
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Fallback overload (rank<0>): build the arch string from the numeric
// gcnArch field. Chosen only when the preferred overload below is
// SFINAE'd out because the device properties lack gcnArchName.
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
    return "gfx" + std::to_string(props.gcnArch);
}
// Preferred overload (rank<1>): selected via expression SFINAE when the
// HIP device properties expose a gcnArchName member — presumably newer
// HIP runtimes; verify against the installed hipDeviceProp_t.
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
    return std::string(props.gcnArchName);
}
// Returns the index of the currently active HIP device; throws when the
// runtime cannot report one.
int get_device_id()
{
    int device = 0;
    if(hipGetDevice(&device) != hipSuccess)
        MIGRAPHX_THROW("No device");
    return device;
}
// Queries the architecture name (e.g. "gfx906") of the active HIP device;
// throws when the device properties cannot be retrieved.
std::string get_device_name()
{
    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, get_device_id()) != hipSuccess)
        MIGRAPHX_THROW("Failed to get device properties");
    return get_arch_name(rank<1>{}, props);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
#define MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/// Returns the architecture name of the currently active HIP device
/// (e.g. "gfx906"); throws on failure to query the device or its properties.
std::string get_device_name();
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <migraphx/remap.hpp> #include <migraphx/remap.hpp>
#include <migraphx/rewrite_batchnorm.hpp> #include <migraphx/rewrite_batchnorm.hpp>
#include <migraphx/rewrite_pooling.hpp> #include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_quantization.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
#include <migraphx/simplify_algebra.hpp> #include <migraphx/simplify_algebra.hpp>
...@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -59,6 +60,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
normalize_ops{}, normalize_ops{},
decompose{}, decompose{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_quantization{},
dead_code_elimination{},
eliminate_data_type{unsupported_types, shape::type_t::float_type}, eliminate_data_type{unsupported_types, shape::type_t::float_type},
simplify_reshapes{}, simplify_reshapes{},
eliminate_identity{}, eliminate_identity{},
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <migraphx/gpu/kernel.hpp> #include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp> #include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
...@@ -74,19 +75,10 @@ migraphx::src_file make_src_file(const std::string& name, const std::string& con ...@@ -74,19 +75,10 @@ migraphx::src_file make_src_file(const std::string& name, const std::string& con
return {name, std::make_pair(content.data(), content.data() + content.size())}; return {name, std::make_pair(content.data(), content.data() + content.size())};
} }
std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
EXPECT(hipGetDevice(&device) == hipSuccess);
EXPECT(hipGetDeviceProperties(&props, device) == hipSuccess);
return "gfx" + std::to_string(props.gcnArch);
}
TEST_CASE(simple_compile_hip) TEST_CASE(simple_compile_hip)
{ {
auto binaries = migraphx::gpu::compile_hip_src( auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", write_2s)}, "", get_device_name()); {make_src_file("main.cpp", write_2s)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1); EXPECT(binaries.size() == 1);
migraphx::argument input{{migraphx::shape::int8_type, {5}}}; migraphx::argument input{{migraphx::shape::int8_type, {5}}};
...@@ -103,7 +95,7 @@ TEST_CASE(simple_compile_hip) ...@@ -103,7 +95,7 @@ TEST_CASE(simple_compile_hip)
TEST_CASE(code_object_hip) TEST_CASE(code_object_hip)
{ {
auto binaries = migraphx::gpu::compile_hip_src( auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", add_2s_binary)}, "", get_device_name()); {make_src_file("main.cpp", add_2s_binary)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1); EXPECT(binaries.size() == 1);
migraphx::shape input{migraphx::shape::int8_type, {5}}; migraphx::shape input{migraphx::shape::int8_type, {5}};
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <type_traits>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -83,8 +84,8 @@ struct function ...@@ -83,8 +84,8 @@ struct function
} }
}; };
template <class Iterator> template <class Stream, class Iterator>
inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last) inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
{ {
if(start != last) if(start != last)
{ {
...@@ -94,22 +95,17 @@ inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last ...@@ -94,22 +95,17 @@ inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last
return s; return s;
} }
inline std::ostream& operator<<(std::ostream& s, std::nullptr_t) template <class Stream>
inline Stream& operator<<(Stream& s, std::nullptr_t)
{ {
s << "nullptr"; s << "nullptr";
return s; return s;
} }
template <class T> template <class Stream,
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v) class Range,
{ class = typename std::enable_if<not std::is_convertible<Range, std::string>{}>::type>
s << "{ "; inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end()))
stream_range(s, v.begin(), v.end());
s << "}";
return s;
}
inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
{ {
s << "{ "; s << "{ ";
stream_range(s, v.begin(), v.end()); stream_range(s, v.begin(), v.end());
......
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
 
 
 
B B
\ No newline at end of file \ No newline at end of file
...@@ -23,4 +23,4 @@ ...@@ -23,4 +23,4 @@
 
 
 
B B
\ No newline at end of file \ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment