Commit 00d5d880 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents 00d90ca8 f60c3815
...@@ -11,7 +11,7 @@ shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const ...@@ -11,7 +11,7 @@ shape miopen_convolution::compute_shape(const std::vector<shape>& inputs) const
check_shapes{inputs, *this}.has(4).standard(); check_shapes{inputs, *this}.has(4).standard();
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5); check_shapes{conv_inputs, *this}.max_ndims(5);
return op.compute_shape(conv_inputs); return op.normalize_compute_shape(conv_inputs);
} }
inline shape reshape_if_1d(const shape& input) inline shape reshape_if_1d(const shape& input)
......
...@@ -11,6 +11,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,6 +11,10 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
#if __AMDGCN_WAVEFRONT_SIZE == 32
#define MIGRAPHX_NO_DPP
#endif
struct sum struct sum
{ {
template <class T, class U> template <class T, class U>
......
#include "migraphx/gpu/device/visit.hpp"
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/reverse.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Reverses `arg1` along the requested axes, writing into `result`.
// Returns `result` (the preallocated output buffer).
argument
reverse(hipStream_t stream, argument result, argument arg1, const std::vector<int64_t>& axes)
{
    auto s = arg1.get_shape();
    // Smuggle the axis indices to the device by storing them as the lens of a
    // dummy shape, so they can be passed through hip_visit_views like a view.
    std::vector<std::size_t> axis_len(axes.begin(), axes.end());
    shape sa{shape::float_type, axis_len};
    std::size_t nelements = s.elements();
    visit_all(result, arg1)([&](auto output1, auto input1) {
        hip_visit_views(output1, input1, s)([&](auto output, auto input, auto hs) {
            hip_visit_views(sa)([&](auto daxes) {
                auto lens = hs.lens;
                // One thread per output element (gs_launch covers nelements).
                gs_launch(stream, nelements)([=](auto i) __device__ {
                    auto idx = hs.multi(i);
                    auto in_idx = idx;
                    // Mirror the multi-index on each reversed axis:
                    // in[axis] = len - 1 - out[axis]; other axes copy through.
                    for(auto axis : daxes.lens)
                        in_idx[axis] = lens[axis] - 1 - idx[axis];
                    output[idx] = input[in_idx];
                });
            });
        });
    });
    return result;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -185,7 +185,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -185,7 +185,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and
wei.lens()[3] != 3 and contains({{1, 1}}, op.stride)) wei.lens()[3] != 3 and contains({{1, 1}}, op.stride))
return false; return false;
return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and return contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding) and
contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation); contains({{0, 0}, {1, 1}}, op.stride) and contains({{1, 1}}, op.dilation);
} }
...@@ -568,7 +568,7 @@ struct miopen_conv_bias ...@@ -568,7 +568,7 @@ struct miopen_conv_bias
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
// TODO: Check slices // TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
...@@ -615,7 +615,7 @@ struct miopen_conv_bias_relu ...@@ -615,7 +615,7 @@ struct miopen_conv_bias_relu
{ {
check_shapes{inputs, *this}.has(5); check_shapes{inputs, *this}.has(5);
// TODO: Check slices // TODO: Check slices
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
......
...@@ -37,8 +37,12 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs) ...@@ -37,8 +37,12 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
} }
template <class T> template <class T>
void gemm_impl( void gemm_impl(context& ctx,
context& ctx, const shape& output_shape, const std::vector<argument>& args, T alpha, T beta) const shape& output_shape,
const std::vector<argument>& args,
T alpha,
T beta,
bool int8_x4_format)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
...@@ -67,6 +71,14 @@ void gemm_impl( ...@@ -67,6 +71,14 @@ void gemm_impl(
compute_type = rocblas_datatype_f32_r; compute_type = rocblas_datatype_f32_r;
} }
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
#endif
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
...@@ -75,7 +87,7 @@ void gemm_impl( ...@@ -75,7 +87,7 @@ void gemm_impl(
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); }; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0) if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{ {
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!"); MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
} }
...@@ -112,11 +124,7 @@ void gemm_impl( ...@@ -112,11 +124,7 @@ void gemm_impl(
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 flag);
rocblas_gemm_flags_pack_int8x4);
#else
0);
#endif
} }
else else
{ {
...@@ -149,11 +157,7 @@ void gemm_impl( ...@@ -149,11 +157,7 @@ void gemm_impl(
compute_type, compute_type,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 flag);
rocblas_gemm_flags_pack_int8x4);
#else
0);
#endif
} }
}); });
} }
...@@ -162,18 +166,20 @@ void gemm(context& ctx, ...@@ -162,18 +166,20 @@ void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta) float beta,
bool int8_x4_format)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
} }
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta) int32_t beta,
bool int8_x4_format)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
} }
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REVERSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REVERSE_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
// Include what we use: the declaration below names std::vector and
// std::int64_t directly, so do not rely on transitive includes.
#include <cstdint>
#include <vector>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// Reverses `arg1` along the given axes into `result` on `stream`
// and returns `result`.
argument
reverse(hipStream_t stream, argument result, argument arg1, const std::vector<int64_t>& axes);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -19,11 +19,13 @@ template <class Op> ...@@ -19,11 +19,13 @@ template <class Op>
struct rocblas_gemm struct rocblas_gemm
{ {
Op op; Op op;
bool int8_x4_format = true;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return migraphx::reflect(self.op, f); return pack_join(migraphx::reflect(self.op, f),
pack(f(self.int8_x4_format, "int8_x4_format")));
} }
std::string name() const std::string name() const
...@@ -43,22 +45,13 @@ struct rocblas_gemm ...@@ -43,22 +45,13 @@ struct rocblas_gemm
batch_not_transposed(inputs[0].strides()); batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides()); batch_not_transposed(inputs[1].strides());
std::size_t kdim = inputs[0].lens().size() - 1;
// k be multiple of 4
if(op.name() == "quant_dot" && (inputs[0].lens()[kdim] % 4) != 0)
{
MIGRAPHX_THROW("GPU_GEMM: size of A {" + to_string_range(inputs[0].lens()) +
"} and B {" + to_string_range(inputs[1].lens()) +
"} must be multiple of 4 for int8 type");
}
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
} }
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
{ {
gemm(ctx, output_shape, args, op.alpha, op.beta); gemm(ctx, output_shape, args, op.alpha, op.beta, int8_x4_format);
return args.back(); return args.back();
} }
......
...@@ -13,12 +13,14 @@ void gemm(context& ctx, ...@@ -13,12 +13,14 @@ void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta); float beta,
bool int8_x4_format);
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta); int32_t beta,
bool int8_x4_format);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -94,7 +94,7 @@ inline convolution_descriptor make_conv(const T& op) ...@@ -94,7 +94,7 @@ inline convolution_descriptor make_conv(const T& op)
std::vector<int> stride(std::max(2, kdims), 1); std::vector<int> stride(std::max(2, kdims), 1);
std::vector<int> dilation(std::max(2, kdims), 1); std::vector<int> dilation(std::max(2, kdims), 1);
std::copy_backward(op.padding.begin(), op.padding.end(), padding.end()); std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end());
std::copy_backward(op.stride.begin(), op.stride.end(), stride.end()); std::copy_backward(op.stride.begin(), op.stride.end(), stride.end());
std::copy_backward(op.dilation.begin(), op.dilation.end(), dilation.end()); std::copy_backward(op.dilation.begin(), op.dilation.end(), dilation.end());
...@@ -145,7 +145,7 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) ...@@ -145,7 +145,7 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
std::vector<int> stride(std::max(2, kdims), 1); std::vector<int> stride(std::max(2, kdims), 1);
std::vector<int> lengths(std::max(2, kdims), 1); std::vector<int> lengths(std::max(2, kdims), 1);
std::copy_backward(op.padding.begin(), op.padding.end(), padding.end()); std::copy_backward(op.padding.begin(), op.padding.begin() + kdims, padding.end());
std::copy_backward(op.stride.begin(), op.stride.end(), stride.end()); std::copy_backward(op.stride.begin(), op.stride.end(), stride.end());
std::copy_backward(op.lengths.begin(), op.lengths.end(), lengths.end()); std::copy_backward(op.lengths.begin(), op.lengths.end(), lengths.end());
......
...@@ -13,7 +13,7 @@ namespace gpu { ...@@ -13,7 +13,7 @@ namespace gpu {
struct pack_int8_args struct pack_int8_args
{ {
std::string name() const { return "gpu::pack_int8_args"; } std::string name() const { return "gpu::pack_int8_args"; }
void apply(module& p) const; void apply(module& m) const;
shape pack_int8_shape(const shape& s) const; shape pack_int8_shape(const shape& s) const;
}; };
......
#ifndef MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_REVERSE_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator wrapping the reference op::reverse: flips a tensor along the
// axes stored in `op`. Follows the backend convention that the last input
// argument is the preallocated output buffer (see output_alias).
struct hip_reverse
{
    op::reverse op;

    // Reflect through to the wrapped reference op so axes are serialized.
    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::reverse"; }
    // Computes the output shape; `inputs` includes the trailing output
    // allocation, which the implementation strips before delegating.
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    // The result aliases the last argument (the output buffer).
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -55,7 +55,8 @@ struct miopen_apply ...@@ -55,7 +55,8 @@ struct miopen_apply
std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{}; std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
instruction_ref last{}; instruction_ref last{};
std::unordered_map<instruction_ref, std::string> prog_output_names{}; std::unordered_map<instruction_ref, std::string> prog_output_names{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true;
context& get_context() const context& get_context() const
{ {
...@@ -97,6 +98,13 @@ struct miopen_apply ...@@ -97,6 +98,13 @@ struct miopen_apply
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context();
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
create_output_names(); create_output_names();
...@@ -160,6 +168,7 @@ struct miopen_apply ...@@ -160,6 +168,7 @@ struct miopen_apply
add_extend_op("reduce_min"); add_extend_op("reduce_min");
add_extend_op("reduce_prod"); add_extend_op("reduce_prod");
add_extend_op("reduce_sum"); add_extend_op("reduce_sum");
add_extend_op("reverse");
add_extend_op("rnn_var_sl_last_output"); add_extend_op("rnn_var_sl_last_output");
add_extend_op("rnn_var_sl_shift_output"); add_extend_op("rnn_var_sl_shift_output");
add_extend_op("rnn_var_sl_shift_sequence"); add_extend_op("rnn_var_sl_shift_sequence");
...@@ -313,7 +322,8 @@ struct miopen_apply ...@@ -313,7 +322,8 @@ struct miopen_apply
} }
} }
return mod->replace_instruction(ins, rocblas_gemm<Op>{Op{op.alpha, beta}}, refs); return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{op.alpha, beta}, int8_x4_format}, refs);
}); });
} }
......
#include <iterator>
#include <migraphx/gpu/pack_int8_args.hpp> #include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp> #include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp> #include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp> #include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/permutation.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
void pack_int8_args::apply(module& p) const static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
{ {
for(auto ins : iterator_for(p)) auto s = ins->get_shape();
auto lens = s.lens();
auto k = lens[lens.size() + offset];
auto pad_k = (k + 3) / 4 * 4;
auto pad_lens = lens;
pad_lens[lens.size() + offset] = pad_k;
std::vector<int64_t> pad_dims(lens.size() * 2, 0);
auto ret_ins = ins;
if(pad_k != k)
{
pad_dims[lens.size() + offset] = pad_k - k;
shape ps{s.type(), pad_lens};
auto ins_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(ps)}}));
auto pad = make_op("pad", {{"pads", pad_dims}});
ret_ins =
m.insert_instruction(std::next(ins), make_op("gpu::pad", pad.to_value()), ins, ins_out);
}
return ret_ins;
}
static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
{
std::vector<instruction_ref> ret_inputs;
auto inputs = ins->inputs();
auto in0 = inputs.at(0);
auto sa = in0->get_shape();
bool transa = sa.transposed();
if(transa)
{
auto perm = find_permutation(sa);
auto val = in0->get_operator().to_value();
if(val.contains("dims"))
{
int offset = static_cast<int>(perm.back()) - static_cast<int>(perm.size());
auto t_in = in0->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
{
shape cs{in0->get_shape().type(), in0->get_shape().lens()};
auto con_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
auto cin0 = m.insert_instruction(ins, make_op("gpu::contiguous"), in0, con_out);
ret_inputs.push_back(pad_ins(m, cin0, -1));
}
}
else
{
ret_inputs.push_back(pad_ins(m, in0, -1));
}
auto in1 = inputs.at(1);
auto sb = in1->get_shape();
bool transb = sb.transposed();
if(transb)
{
auto perm = find_permutation(sb);
auto val = in1->get_operator().to_value();
if(val.contains("dims"))
{
int offset = static_cast<int>(perm[perm.size() - 2]) - static_cast<int>(perm.size());
auto t_in = in1->inputs().front();
auto p_in = pad_ins(m, t_in, offset);
auto dims = val.at("dims").to_vector<int64_t>();
auto r_in = m.insert_instruction(ins, make_op("transpose", {{"dims", dims}}), p_in);
ret_inputs.push_back(r_in);
}
else
{
shape cs{in1->get_shape().type(), in1->get_shape().lens()};
auto con_out =
m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
auto cin1 = m.insert_instruction(ins, make_op("gpu::contiguous"), in1, con_out);
ret_inputs.push_back(pad_ins(m, cin1, -2));
}
}
else
{
ret_inputs.push_back(pad_ins(m, in1, -2));
}
std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(ret_inputs));
return ret_inputs;
}
void pack_int8_args::apply(module& m) const
{
for(auto ins : iterator_for(m))
{ {
if(ins->name() == "gpu::quant_gemm") if(ins->name() == "gpu::quant_gemm")
{ {
auto val = ins->get_operator().to_value();
assert(val.contains("int8_x4_format"));
if(not val.at("int8_x4_format").to<bool>())
{
return;
}
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto lens = inputs.at(0)->get_shape().lens();
// gemm need the k to be multiple of 4, so need packing that dimension
auto old_inputs = inputs;
if((lens.back() % 4) != 0)
{
inputs = pad_inputs(m, ins);
}
bool transa = inputs[0]->get_shape().transposed(); bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed(); bool transb = inputs[1]->get_shape().transposed();
if(!transb) if(!transb)
{ {
auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()}); auto packed_b = m.insert_instruction(
auto output_b = ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b}); auto output_b = m.insert_instruction(
instruction::replace_argument(ins, inputs[1], output_b); ins, make_op("gpu::int8_gemm_pack_a"), {inputs[1], packed_b});
inputs[1] = output_b;
} }
if(transa) if(transa)
{ {
auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()}); auto packed_a = m.insert_instruction(
auto output_a = ins, make_op("hip::allocate", {{"shape", to_value(inputs[0]->get_shape())}}));
p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a}); auto output_a = m.insert_instruction(
instruction::replace_argument(ins, inputs[0], output_a); ins, make_op("gpu::int8_gemm_pack_b"), {inputs[0], packed_a});
inputs[0] = output_a;
}
if(inputs != old_inputs)
{
m.replace_instruction(ins, ins->get_operator(), inputs);
} }
} }
else if(ins->name() == "gpu::quant_convolution") else if(ins->name() == "gpu::quant_convolution")
{ {
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto packed_x = auto packed_x = m.insert_instruction(
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())}); ins,
make_op("hip::allocate",
{{"shape", to_value(pack_int8_shape(inputs[0]->get_shape()))}}));
auto output_x = auto output_x =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x}); m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[0], packed_x});
instruction::replace_argument(ins, inputs[0], output_x); instruction::replace_argument(ins, inputs[0], output_x);
auto packed_w = auto packed_w = m.insert_instruction(
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())}); ins,
make_op("hip::allocate",
{{"shape", to_value(pack_int8_shape(inputs[1]->get_shape()))}}));
auto output_w = auto output_w =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w}); m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[1], packed_w});
instruction::replace_argument(ins, inputs[1], output_w); instruction::replace_argument(ins, inputs[1], output_w);
} }
} }
......
...@@ -10,7 +10,7 @@ shape miopen_pooling::compute_shape(const std::vector<shape>& inputs) const ...@@ -10,7 +10,7 @@ shape miopen_pooling::compute_shape(const std::vector<shape>& inputs) const
check_shapes{inputs, *this}.has(2).standard(); check_shapes{inputs, *this}.has(2).standard();
std::vector<shape> pooling_input = {inputs.at(0)}; std::vector<shape> pooling_input = {inputs.at(0)};
check_shapes{pooling_input, *this}.max_ndims(5); check_shapes{pooling_input, *this}.max_ndims(5);
return op.compute_shape(pooling_input); return op.normalize_compute_shape(pooling_input);
} }
inline void reshape_if_1d(shape& input) inline void reshape_if_1d(shape& input)
......
...@@ -10,7 +10,7 @@ namespace gpu { ...@@ -10,7 +10,7 @@ namespace gpu {
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{ {
check_shapes{inputs, *this}.has(4).standard(); check_shapes{inputs, *this}.has(4).standard();
return op.compute_shape({inputs.at(0), inputs.at(1)}); return op.normalize_compute_shape({inputs.at(0), inputs.at(1)});
} }
argument miopen_quant_convolution::compute(context& ctx, argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
......
#include <migraphx/gpu/reverse.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/reverse.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Infers the output shape of gpu::reverse. The trailing input is the
// preallocated output buffer, not a logical operand, so it is removed
// before delegating to the reference op's shape inference.
shape hip_reverse::compute_shape(std::vector<shape> inputs) const
{
    inputs.erase(inputs.end() - 1);
    return op.normalize_compute_shape(inputs);
}
// Runs the device reverse kernel: args.front() is the input tensor and
// args.back() is the destination buffer, which is also returned.
argument hip_reverse::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    auto stream = ctx.get_stream().get();
    return device::reverse(stream, args.back(), args.front(), op.axes);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp> #include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp> #include <migraphx/normalize_ops.hpp>
#include <migraphx/propagate_constant.hpp> #include <migraphx/propagate_constant.hpp>
...@@ -50,6 +52,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -50,6 +52,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::bool_type);
unsupported_types.erase(shape::type_t::int8_type); unsupported_types.erase(shape::type_t::int8_type);
unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::tuple_type);
// clang-format off // clang-format off
return return
{ {
...@@ -61,10 +64,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -61,10 +64,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
eliminate_identity{}, eliminate_identity{},
eliminate_pad{}, eliminate_pad{},
dead_code_elimination{}, dead_code_elimination{},
insert_pad{},
dead_code_elimination{},
rewrite_batchnorm{}, rewrite_batchnorm{},
dead_code_elimination{}, dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
inline_module{},
rewrite_pooling{}, rewrite_pooling{},
dead_code_elimination{}, dead_code_elimination{},
eliminate_common_subexpression{}, eliminate_common_subexpression{},
......
...@@ -205,7 +205,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>> ...@@ -205,7 +205,10 @@ struct ref_convolution : auto_register_op<ref_convolution<Op>>
} }
std::string name() const { return "ref::" + op.name(); } std::string name() const { return "ref::" + op.name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, shape output_shape, std::vector<argument> args) const argument compute(context&, shape output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
...@@ -370,7 +373,10 @@ struct ref_im2col ...@@ -370,7 +373,10 @@ struct ref_im2col
} }
static std::string name() { return "ref::im2col"; } static std::string name() { return "ref::im2col"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
...@@ -471,7 +477,10 @@ struct ref_pooling : auto_register_op<ref_pooling<Op>> ...@@ -471,7 +477,10 @@ struct ref_pooling : auto_register_op<ref_pooling<Op>>
} }
std::string name() const { return "ref::pooling_" + Op::name(); } std::string name() const { return "ref::pooling_" + Op::name(); }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); } shape compute_shape(const std::vector<shape>& inputs) const
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const argument compute(context&, const shape& output_shape, std::vector<argument> args) const
{ {
argument result{output_shape}; argument result{output_shape};
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <migraphx/pass.hpp> #include <migraphx/pass.hpp>
#include <migraphx/auto_contiguous.hpp> #include <migraphx/auto_contiguous.hpp>
#include <migraphx/rewrite_rnn.hpp> #include <migraphx/rewrite_rnn.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/dead_code_elimination.hpp> #include <migraphx/dead_code_elimination.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/normalize_ops.hpp> #include <migraphx/normalize_ops.hpp>
...@@ -18,6 +20,10 @@ std::string target::name() const { return "ref"; } ...@@ -18,6 +20,10 @@ std::string target::name() const { return "ref"; }
std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const std::vector<pass> target::get_passes(migraphx::context&, const compile_options&) const
{ {
return {normalize_ops{}, return {normalize_ops{},
eliminate_pad{},
dead_code_elimination{},
insert_pad{},
dead_code_elimination{},
rewrite_rnn{}, rewrite_rnn{},
dead_code_elimination{}, dead_code_elimination{},
auto_contiguous{}, auto_contiguous{},
......
...@@ -62,16 +62,7 @@ struct parse_conv : op_parser<parse_conv> ...@@ -62,16 +62,7 @@ struct parse_conv : op_parser<parse_conv>
calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h); calculate_padding(0, pads, input_dims[2], op.stride[0], op.dilation[0], weight_h);
calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w); calculate_padding(1, pads, input_dims[3], op.stride[1], op.dilation[1], weight_w);
if(pads[0] != pads[2] || pads[1] != pads[3]) op.padding = std::vector<size_t>(pads.begin(), pads.end());
{
std::vector<int64_t> padding = {0, 0, pads[0], pads[1], 0, 0, pads[2], pads[3]};
l0 = info.add_instruction(migraphx::make_op("pad", {{"pads", padding}}), l0);
}
else
{
op.padding[0] = pads[0];
op.padding[1] = pads[1];
}
} }
else if(pad_mode.find("VALID") != std::string::npos) else if(pad_mode.find("VALID") != std::string::npos)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment