Commit 98fd5e1d authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into eliminate-more-contiguous

parents f7a6d87f a1c7e7a5
......@@ -3,8 +3,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/op/convert.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/convert.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -12,7 +10,7 @@ namespace gpu {
struct context;
struct hip_convert : unary_device<hip_convert, device::convert>
struct hip_convert
{
op::convert op;
......@@ -22,13 +20,15 @@ struct hip_convert : unary_device<hip_convert, device::convert>
return migraphx::reflect(self.op, f);
}
hip_convert(op::convert oper) : op(oper) {}
std::string name() const { return "gpu::convert"; }
shape compute_shape(std::vector<shape> inputs) const
shape compute_shape(std::vector<shape> inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
inputs.pop_back();
check_shapes{inputs}.packed();
return op.compute_shape(inputs);
return shapes.size() - 1;
}
};
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg);
void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_ROUND_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void round(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGMOID_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGMOID_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void sigmoid(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_int8_conv_pack
{
std::string name() const { return "gpu::int8_conv_pack"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_int8_gemm_pack_a
{
std::string name() const { return "gpu::int8_gemm_pack_a"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
struct hip_int8_gemm_pack_b
{
std::string name() const { return "gpu::int8_gemm_pack_b"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -34,11 +34,11 @@ Result make_obj(F f, Ts... xs)
auto status = f(&x, xs...);
Result r{x};
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen call failed");
MIGRAPHX_THROW("MAKE_OBJ: MIOpen call failed");
return r;
}
inline tensor_descriptor make_tensor(const migraphx::shape& s)
inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false)
{
auto t = make_obj<tensor_descriptor>(&miopenCreateTensorDescriptor);
// Convert to ints
......@@ -49,13 +49,33 @@ inline tensor_descriptor make_tensor(const migraphx::shape& s)
d = miopenFloat;
else if(s.type() == shape::half_type)
d = miopenHalf;
else if(s.type() == shape::int32_type)
d = miopenInt32;
else if(s.type() == shape::int8_type)
{
if(pack)
{
// update the lens and corresponding strides
d = miopenInt8x4;
lens[1] = ((lens[1] + 3) / 4) * 4;
strides[0] = strides[1] * lens[1];
}
else
{
d = miopenInt8;
}
}
else
MIGRAPHX_THROW("Unsupported type");
{
MIGRAPHX_THROW("MAKE_TENSOR: unsupported type");
}
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
return t;
}
inline convolution_descriptor make_conv(const migraphx::op::convolution& op)
template <class T>
inline convolution_descriptor make_conv(const T& op)
{
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenConvolutionMode_t c_mode = miopenConvolution;
......
#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct pack_int8_args
{
std::string name() const { return "gpu::pack_int8_args"; }
void apply(program& p) const;
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_quant_convolution
{
op::quant_convolution op;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::quant_convolution::reflect(self.op, f);
}
std::string name() const { return "gpu::quant_convolution"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
shape compile(context& ctx, const shape& output_shape, std::vector<shape> inputs);
void finalize(context& ctx, const shape& output_shape, std::vector<shape> inputs);
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
private:
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_QUANT_GEMM_HPP
#include <migraphx/shape.hpp>
#include <migraphx/op/quant_dot.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct rocblas_quant_gemm
{
op::quant_dot op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::quant_gemm"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#define MIGRAPHX_GUARD_RTGLIB_ROUND_HPP
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/round.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct hip_round : unary_device<hip_round, device::round>
{
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#include <migraphx/shape.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/device/sigmoid.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct miopen_sigmoid
struct hip_sigmoid : unary_device<hip_sigmoid, device::sigmoid>
{
shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::sigmoid"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
......
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).standard();
return inputs.at(0);
}
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
auto arg_desc = make_tensor(args[0].get_shape());
auto arg_desc_vec4 = make_tensor(args[0].get_shape(), true);
float alpha = 1;
float beta = 0;
// pack input to vec4 format
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
arg_desc.get(),
args[0].implicit(),
&beta,
arg_desc_vec4.get(),
args[1].implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
}
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_a::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_a(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{{inputs.at(0)}, *this}.has(1).not_broadcasted().packed();
return inputs.at(0);
}
argument
hip_int8_gemm_pack_b::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::int8_gemm_pack_b(ctx.get_stream().get(), args[1], args[0]);
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -16,6 +16,7 @@
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/contiguous.hpp>
#include <migraphx/gpu/relu.hpp>
#include <migraphx/gpu/sigmoid.hpp>
......@@ -46,6 +47,7 @@
#include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp>
......@@ -53,11 +55,13 @@
#include <migraphx/gpu/convert.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/reduce_sum.hpp>
#include <migraphx/gpu/round.hpp>
#include <migraphx/gpu/rsqrt.hpp>
#include <migraphx/gpu/sqrt.hpp>
#include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <utility>
#include <functional>
#include <algorithm>
......@@ -83,7 +87,6 @@ struct miopen_apply
void init()
{
this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
......@@ -109,12 +112,15 @@ struct miopen_apply
add_generic_op<hip_max>("max");
add_generic_op<hip_min>("min");
add_generic_op<hip_rsqrt>("rsqrt");
add_generic_op<hip_round>("round");
add_generic_op<hip_pow>("pow");
add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_relu>("relu");
add_generic_op<hip_sign>("sign");
add_generic_op<hip_sigmoid>("sigmoid");
add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax");
......@@ -130,6 +136,8 @@ struct miopen_apply
add_lrn_op();
add_convolution_op();
add_quant_convolution_op();
// add_quant_dot_op();
add_pooling_op();
add_batch_norm_inference_op();
}
......@@ -176,6 +184,21 @@ struct miopen_apply
});
}
void add_quant_convolution_op()
{
apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(ins, conv, args[0], args[1], workspace, output);
});
}
void add_pooling_op()
{
apply_map.emplace("pooling", [=](instruction_ref ins) {
......
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
void pack_int8_args::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->name() == "gpu::quant_gemm")
{
auto inputs = ins->inputs();
bool transa = inputs[0]->get_shape().transposed();
bool transb = inputs[1]->get_shape().transposed();
if(!transb)
{
auto packed_b = p.insert_instruction(ins, hip_allocate{inputs[1]->get_shape()});
auto output_b =
p.insert_instruction(ins, hip_int8_gemm_pack_a{}, {inputs[1], packed_b});
instruction::replace_argument(ins, inputs[1], output_b);
}
if(transa)
{
auto packed_a = p.insert_instruction(ins, hip_allocate{inputs[0]->get_shape()});
auto output_a =
p.insert_instruction(ins, hip_int8_gemm_pack_b{}, {inputs[0], packed_a});
instruction::replace_argument(ins, inputs[0], output_a);
}
}
else if(ins->name() == "gpu::quant_convolution")
{
auto inputs = ins->inputs();
auto packed_x =
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())});
auto output_x =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x});
instruction::replace_argument(ins, inputs[0], output_x);
auto packed_w =
p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())});
auto output_w =
p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w});
instruction::replace_argument(ins, inputs[1], output_w);
}
}
}
shape pack_int8_args::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/quant_convolution.hpp>
#include <migraphx/gpu/device/convert.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_quant_convolution::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(4).standard();
return op.compute_shape({inputs.at(0), inputs.at(1)});
}
argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
auto x_desc = make_tensor(args[0].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
w_desc.get(),
args[1].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3];
}
shape miopen_quant_convolution::compile(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
shape workspace_shape{};
auto x_desc = make_tensor(inputs[0], true);
auto w_desc = make_tensor(inputs[1], true);
auto y_desc = make_tensor(output_shape);
std::size_t workspace_size = 0;
miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
w_desc.get(),
x_desc.get(),
cd.get(),
y_desc.get(),
&workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}};
auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1;
miopenConvAlgoPerf_t perf;
auto status = miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
x_desc.get(),
arg_vec4_x.implicit(),
w_desc.get(),
arg_vec4_w.implicit(),
cd.get(),
y_desc.get(),
y.implicit(),
1,
&algo_count,
&perf,
workspace.implicit(),
workspace_size,
false);
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: find convolution failed");
}
handle = ctx.get_stream().get_miopen();
algo = perf.fwd_algo;
return shape{shape::int8_type, {perf.memory}};
}
void miopen_quant_convolution::finalize(context& ctx,
const shape& output_shape,
std::vector<shape> inputs)
{
if(handle == ctx.get_stream().get_miopen())
return;
// Check that workspace hasn't changed
auto size = inputs.at(2).bytes();
auto ws = compile(ctx, output_shape, std::move(inputs));
if(ws.bytes() > size)
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
MIGRAPHX_THROW("PACK_INT8_SHAPE: only process int8_type");
}
auto lens = s.lens();
auto strides = s.strides();
lens[1] = (lens[1] + 3) / 4 * 4;
strides[0] = strides[1] * lens[1];
return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
#include <fstream>
#include <iomanip>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes);
}
void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
}
}
argument rocblas_quant_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
bool is_3inputs = (args.size() == 4);
int32_t beta = 0;
if(is_3inputs)
{
beta = op.beta;
}
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(op.alpha);
auto beta_r = as(beta);
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
assert(k % 4 == 0);
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1)
{
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
}
else
{
rocblas_gemm_strided_batched_ex(
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
k * n,
to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
m * k,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
m * n,
num_matrices,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0,
0,
nullptr,
nullptr);
}
});
return is_3inputs ? args[3] : args[2];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape miopen_sigmoid::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).not_broadcasted();
return inputs.at(1);
}
argument miopen_sigmoid::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{
float alpha = 1;
float beta = 0;
auto x_desc = make_tensor(args[0].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(ctx.get_stream().get_miopen(),
ad.get(),
&alpha,
x_desc.get(),
args[0].implicit(),
&beta,
y_desc.get(),
args[1].implicit());
return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,6 +22,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp>
......@@ -64,6 +65,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{},
adjust_allocation{},
dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
fuse_ops{&ctx},
dead_code_elimination{},
write_literals{&ctx},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment