"vscode:/vscode.git/clone" did not exist on "4976eb0c09a114e9facc06654cabf91fdb196532"
Commit eab3cafb authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge changes from int8_miopen_call

parents 5edf61dc 93b7fb54
...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") ...@@ -32,7 +32,9 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
endif() endif()
endif() endif()
if(CMAKE_CXX_COMPILER MATCHES ".*hcc") include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("--cuda-host-only -x hip" HAS_HIP)
if(HAS_HIP)
message(STATUS "Enable miopen backend") message(STATUS "Enable miopen backend")
set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "") set(MIGRAPHX_ENABLE_GPU On CACHE BOOL "")
else() else()
......
...@@ -37,7 +37,7 @@ add_library(migraphx_device ...@@ -37,7 +37,7 @@ add_library(migraphx_device
device/pad.cpp device/pad.cpp
device/gather.cpp device/gather.cpp
device/sub.cpp device/sub.cpp
device/pack.cpp device/int8_gemm_pack.cpp
device/div.cpp device/div.cpp
device/clip.cpp device/clip.cpp
device/reduce_sum.cpp device/reduce_sum.cpp
...@@ -85,9 +85,12 @@ add_library(migraphx_gpu ...@@ -85,9 +85,12 @@ add_library(migraphx_gpu
lrn.cpp lrn.cpp
schedule_model.cpp schedule_model.cpp
adjust_allocation.cpp adjust_allocation.cpp
pack_int8_args.cpp
clip.cpp clip.cpp
reduce_sum.cpp reduce_sum.cpp
reduce_mean.cpp reduce_mean.cpp
int8_gemm_pack.cpp
int8_conv_pack.cpp
) )
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu) set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
rocm_clang_tidy_check(migraphx_gpu) rocm_clang_tidy_check(migraphx_gpu)
......
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/gpu/device/pack.hpp> #include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/hip.hpp> #include <migraphx/gpu/hip.hpp>
namespace migraphx { namespace migraphx {
...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -11,7 +11,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg) void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto comp_shape = arg.get_shape(); auto comp_shape = arg.get_shape();
auto out_lens = comp_shape.lens(); auto out_lens = comp_shape.lens();
...@@ -38,7 +38,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg) ...@@ -38,7 +38,7 @@ void pack_a(hipStream_t stream, const argument& result, const argument& arg)
}); });
} }
void pack_b(hipStream_t stream, const argument& result, const argument& arg) void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg)
{ {
auto trans_shape = arg.get_shape(); auto trans_shape = arg.get_shape();
auto out_lens = trans_shape.lens(); auto out_lens = trans_shape.lens();
......
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PACK_HPP #define MIGRAPHX_GUARD_RTGLIB_DEVICE_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp> #include <migraphx/argument.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
...@@ -10,11 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,11 +10,9 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void pack_a(hipStream_t stream, const argument& result, const argument& arg); void int8_gemm_pack_a(hipStream_t stream, const argument& result, const argument& arg);
void pack_b(hipStream_t stream, const argument& result, const argument& arg); void int8_gemm_pack_b(hipStream_t stream, const argument& result, const argument& arg);
void sync_stream(hipStream_t stream);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator registered under the name "gpu::int8_conv_pack".
// Only the declarations live here; compute_shape() and compute() are defined
// out of line in the matching source file.
struct miopen_int8_conv_pack
{
// Operator name string.
std::string name() const { return "gpu::int8_conv_pack"; }
// Derives the output shape from the argument shapes (defined out of line).
shape compute_shape(const std::vector<shape>& inputs) const;
// Executes the operator on `args` using `ctx`; the shape parameter is unnamed
// in this declaration (defined out of line).
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
// Returns the index of the last entry of `shapes`.
// NOTE(review): presumably this tells the framework that the result aliases
// the trailing, pre-allocated argument - confirm against the operation interface.
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU operator registered under the name "gpu::int8_gemm_pack_a".
// Only the declarations live here; compute_shape() and compute() are defined
// out of line in the matching source file.
struct hip_int8_gemm_pack_a
{
// Operator name string.
std::string name() const { return "gpu::int8_gemm_pack_a"; }
// Derives the output shape from the argument shapes (defined out of line).
shape compute_shape(const std::vector<shape>& inputs) const;
// Executes the operator on `args` using `ctx`; the shape parameter is unnamed
// in this declaration (defined out of line).
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
// Returns the index of the last entry of `shapes`.
// NOTE(review): presumably this tells the framework that the result aliases
// the trailing, pre-allocated argument - confirm against the operation interface.
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
// GPU operator registered under the name "gpu::int8_gemm_pack_b".
// Only the declarations live here; compute_shape() and compute() are defined
// out of line in the matching source file.
struct hip_int8_gemm_pack_b
{
// Operator name string.
std::string name() const { return "gpu::int8_gemm_pack_b"; }
// Derives the output shape from the argument shapes (defined out of line).
shape compute_shape(const std::vector<shape>& inputs) const;
// Executes the operator on `args` using `ctx`; the shape parameter is unnamed
// in this declaration (defined out of line).
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const;
// Returns the index of the last entry of `shapes`.
// NOTE(review): presumably this tells the framework that the result aliases
// the trailing, pre-allocated argument - confirm against the operation interface.
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#define MIGRAPHX_GUARD_RTGLIB_PACK_INT8_ARGS_HPP
#include <migraphx/program.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Program pass named "gpu::pack_int8_args".
// Only the declarations live here; apply() and pack_int8_shape() are defined
// out of line in the matching source file.
struct pack_int8_args
{
// Pass name string.
std::string name() const { return "gpu::pack_int8_args"; }
// Rewrites program `p` in place (defined out of line).
void apply(program& p) const;
// Computes a shape derived from `s` for a packed int8 buffer (defined out of line).
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -17,8 +17,6 @@ struct miopen_quant_convolution ...@@ -17,8 +17,6 @@ struct miopen_quant_convolution
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; miopenHandle_t handle = nullptr;
argument arg_vec4_x{};
argument arg_vec4_w{};
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -39,7 +37,7 @@ struct miopen_quant_convolution ...@@ -39,7 +37,7 @@ struct miopen_quant_convolution
} }
private: private:
shape pack_int8_shape(shape& s); shape pack_int8_shape(const shape& s) const;
}; };
} // namespace gpu } // namespace gpu
......
...@@ -10,7 +10,7 @@ namespace gpu { ...@@ -10,7 +10,7 @@ namespace gpu {
struct context; struct context;
struct miopen_quant_gemm struct rocblas_quant_gemm
{ {
op::quant_dot op; op::quant_dot op;
...@@ -24,6 +24,7 @@ struct miopen_quant_gemm ...@@ -24,6 +24,7 @@ struct miopen_quant_gemm
shape compute_shape(const std::vector<shape>& inputs) const; shape compute_shape(const std::vector<shape>& inputs) const;
argument argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const; compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void batch_not_transposed(const std::vector<std::size_t>& strides) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{ {
return shapes.size() - 1; return shapes.size() - 1;
......
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Validates the tensor that is going to be packed and reports its shape as
// the operator's output shape.
//
// Only inputs.at(0) is handed to check_shapes (which is asked for exactly one
// shape that passes the standard() check); any further entries of `inputs`
// are not inspected here. inputs.at(0) throws if `inputs` is empty.
shape miopen_int8_conv_pack::compute_shape(const std::vector<shape>& inputs) const
{
    const auto& src_shape = inputs.at(0);
    check_shapes{{src_shape}, *this}.has(1).standard();
    return src_shape;
}
// Transforms args[0] into the buffer args[1] with miopenTransformTensor and
// returns args[1].
//
// Both tensor descriptors are built from the shape of args[0]; the
// destination descriptor is created with the extra `true` flag of
// make_tensor. NOTE(review): presumably that flag selects the int8x4
// ("vec4") layout - confirm against make_tensor.
//
// Throws (via MIGRAPHX_THROW) when MIOpen does not report success.
argument
miopen_int8_conv_pack::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    const auto& src = args[0];
    const auto& dst = args[1];

    auto src_desc = make_tensor(src.get_shape());
    auto dst_desc = make_tensor(src.get_shape(), true);

    // Scaling factors handed to MIOpen: alpha = 1, beta = 0.
    float alpha = 1;
    float beta  = 0;

    const auto status = miopenTransformTensor(ctx.get_stream().get_miopen(),
                                              &alpha,
                                              src_desc.get(),
                                              src.implicit(),
                                              &beta,
                                              dst_desc.get(),
                                              dst.implicit());
    if(status != miopenStatusSuccess)
    {
        MIGRAPHX_THROW("INT8_CONV_PACK: transform input tensor failed");
    }

    return dst;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/device/int8_gemm_pack.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Validates the tensor that is going to be packed and reports its shape as
// the operator's output shape.
//
// Only inputs.at(0) is handed to check_shapes (exactly one shape, which must
// pass the not_broadcasted() and packed() checks); any further entries of
// `inputs` are not inspected here. inputs.at(0) throws if `inputs` is empty.
shape hip_int8_gemm_pack_a::compute_shape(const std::vector<shape>& inputs) const
{
    const auto& src_shape = inputs.at(0);
    check_shapes{{src_shape}, *this}.has(1).not_broadcasted().packed();
    return src_shape;
}
// Launches device::int8_gemm_pack_a on the context's stream, with args[1] as
// the result buffer and args[0] as the source, then returns args[1].
argument
hip_int8_gemm_pack_a::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    const auto& src = args[0];
    const auto& dst = args[1];
    device::int8_gemm_pack_a(ctx.get_stream().get(), dst, src);
    return dst;
}
// Validates the tensor that is going to be packed and reports its shape as
// the operator's output shape.
//
// Only inputs.at(0) is handed to check_shapes (exactly one shape, which must
// pass the not_broadcasted() and packed() checks); any further entries of
// `inputs` are not inspected here. inputs.at(0) throws if `inputs` is empty.
shape hip_int8_gemm_pack_b::compute_shape(const std::vector<shape>& inputs) const
{
    const auto& src_shape = inputs.at(0);
    check_shapes{{src_shape}, *this}.has(1).not_broadcasted().packed();
    return src_shape;
}
// Launches device::int8_gemm_pack_b on the context's stream, with args[1] as
// the result buffer and args[0] as the source, then returns args[1].
argument
hip_int8_gemm_pack_b::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    const auto& src = args[0];
    const auto& dst = args[1];
    device::int8_gemm_pack_b(ctx.get_stream().get(), dst, src);
    return dst;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -59,6 +59,7 @@ ...@@ -59,6 +59,7 @@
#include <migraphx/gpu/reduce_mean.hpp> #include <migraphx/gpu/reduce_mean.hpp>
#include <migraphx/gpu/pow.hpp> #include <migraphx/gpu/pow.hpp>
#include <migraphx/gpu/sqdiff.hpp> #include <migraphx/gpu/sqdiff.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <utility> #include <utility>
#include <functional> #include <functional>
#include <algorithm> #include <algorithm>
...@@ -115,6 +116,7 @@ struct miopen_apply ...@@ -115,6 +116,7 @@ struct miopen_apply
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<hip_softmax, op::softmax>("softmax"); add_extend_op<hip_softmax, op::softmax>("softmax");
...@@ -131,7 +133,7 @@ struct miopen_apply ...@@ -131,7 +133,7 @@ struct miopen_apply
add_lrn_op(); add_lrn_op();
add_convolution_op(); add_convolution_op();
add_quant_convolution_op(); add_quant_convolution_op();
add_quant_dot_op(); // add_quant_dot_op();
add_pooling_op(); add_pooling_op();
add_batch_norm_inference_op(); add_batch_norm_inference_op();
} }
...@@ -182,32 +184,14 @@ struct miopen_apply ...@@ -182,32 +184,14 @@ struct miopen_apply
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)}; auto conv = miopen_quant_convolution{op, make_conv(op)};
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs();
auto workspace = insert_allocation(ins, ws, "workspace"); auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction( return prog->replace_instruction(ins, conv, args[0], args[1], workspace, output);
ins, conv, ins->inputs().at(0), ins->inputs().at(1), workspace, output);
});
}
void add_quant_dot_op()
{
apply_map.emplace("quant_dot", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_dot>(ins->get_operator());
auto inputs = ins->inputs();
auto in_shapes = to_shapes(inputs);
auto pack_a = insert_allocation(ins, in_shapes[0], "pack_a");
auto pack_b = insert_allocation(ins, in_shapes[1], "pack_b");
auto output = insert_allocation(ins, ins->get_shape());
inputs.push_back(pack_a);
inputs.push_back(pack_b);
inputs.push_back(output);
return prog->replace_instruction(ins, miopen_quant_gemm{op}, inputs);
}); });
} }
......
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Walks every instruction of `p` and inserts explicit packing instructions in
// front of the int8 GPU operators:
//
//  * "gpu::quant_gemm": when input 1 is not transposed it is routed through
//    hip_int8_gemm_pack_a; when input 0 is transposed it is routed through
//    hip_int8_gemm_pack_b. Each packing buffer is allocated with the shape of
//    the input it replaces.
//  * "gpu::quant_convolution": inputs 0 and 1 are each routed through
//    miopen_int8_conv_pack, into buffers whose shape comes from
//    pack_int8_shape().
//
// In every case the consumer's argument is rewired to the packing result.
void pack_int8_args::apply(program& p) const
{
    // Allocates a buffer of `buf_shape` right before `consumer`, runs `pack_op`
    // from `src` into it, and makes `consumer` read the packed result instead.
    auto repack = [&](auto consumer, auto src, const shape& buf_shape, auto pack_op) {
        auto buffer = p.insert_instruction(consumer, hip_allocate{buf_shape});
        auto packed = p.insert_instruction(consumer, pack_op, {src, buffer});
        instruction::replace_argument(consumer, src, packed);
    };

    for(auto ins : iterator_for(p))
    {
        if(ins->name() == "gpu::quant_gemm")
        {
            // Work on a copy: replace_argument below rewrites ins's own input list.
            const auto operands = ins->inputs();
            const bool a_transposed = operands[0]->get_shape().transposed();
            const bool b_transposed = operands[1]->get_shape().transposed();

            // The second operand is handled first, then the first operand.
            if(not b_transposed)
            {
                repack(ins, operands[1], operands[1]->get_shape(), hip_int8_gemm_pack_a{});
            }
            if(a_transposed)
            {
                repack(ins, operands[0], operands[0]->get_shape(), hip_int8_gemm_pack_b{});
            }
        }
        else if(ins->name() == "gpu::quant_convolution")
        {
            // Work on a copy: replace_argument below rewrites ins's own input list.
            const auto operands = ins->inputs();

            repack(ins,
                   operands[0],
                   pack_int8_shape(operands[0]->get_shape()),
                   miopen_int8_conv_pack{});
            repack(ins,
                   operands[1],
                   pack_int8_shape(operands[1]->get_shape()),
                   miopen_int8_conv_pack{});
        }
    }
}
// Builds the shape of the buffer that receives a packed int8 tensor.
//
// Dimension 1 of `s` is rounded up to the next multiple of 4 and stride 0 is
// recomputed as strides[1] * lens[1] using the rounded length; all other
// lengths and strides are copied from `s` unchanged.
//
// Throws (via MIGRAPHX_THROW) when `s` is not int8_type, or when it has fewer
// than two dimensions (the computation below indexes element 1 of both the
// lens and the strides vectors, which would otherwise be out of bounds).
shape pack_int8_args::pack_int8_shape(const shape& s) const
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
    }

    auto lens    = s.lens();
    auto strides = s.strides();
    if(lens.size() < 2 or strides.size() < 2)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: shape must have at least 2 dimensions");
    }

    // Round dimension 1 up to a multiple of 4.
    lens[1]    = (lens[1] + 3) / 4 * 4;
    strides[0] = strides[1] * lens[1];

    return {s.type(), lens, strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -16,54 +16,26 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -16,54 +16,26 @@ argument miopen_quant_convolution::compute(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args) const const std::vector<argument>& args) const
{ {
auto x_desc = make_tensor(args[0].get_shape()); auto x_desc = make_tensor(args[0].get_shape(), true);
auto x_desc_vec4 = make_tensor(args[0].get_shape(), true); auto w_desc = make_tensor(args[1].get_shape(), true);
auto w_desc = make_tensor(args[1].get_shape()); auto y_desc = make_tensor(output_shape);
auto w_desc_vec4 = make_tensor(args[1].get_shape(), true);
auto y_desc = make_tensor(output_shape);
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
// pack input to vec4 format auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
auto status = miopenTransformTensor(ctx.get_stream().get_miopen(), &alpha,
&alpha, x_desc.get(),
x_desc.get(), args[0].implicit(),
args[0].implicit(), w_desc.get(),
&beta, args[1].implicit(),
x_desc_vec4.get(), cd.get(),
arg_vec4_x.implicit()); algo,
if(status != miopenStatusSuccess) &beta,
{ y_desc.get(),
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform input tensor failed"); args[3].implicit(),
} args[2].implicit(),
args[2].get_shape().bytes());
// pack input to vec4 format
status = miopenTransformTensor(ctx.get_stream().get_miopen(),
&alpha,
w_desc.get(),
args[1].implicit(),
&beta,
w_desc_vec4.get(),
arg_vec4_w.implicit());
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: transform weight tensor failed");
}
status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc_vec4.get(),
arg_vec4_x.implicit(),
w_desc_vec4.get(),
arg_vec4_w.implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{ {
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
...@@ -90,10 +62,10 @@ shape miopen_quant_convolution::compile(context& ctx, ...@@ -90,10 +62,10 @@ shape miopen_quant_convolution::compile(context& ctx,
&workspace_size); &workspace_size);
workspace_shape = shape{shape::int8_type, {workspace_size}}; workspace_shape = shape{shape::int8_type, {workspace_size}};
arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0]))); auto arg_vec4_x = to_gpu(generate_argument(pack_int8_shape(inputs[0])));
arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1]))); auto arg_vec4_w = to_gpu(generate_argument(pack_int8_shape(inputs[1])));
auto y = allocate_gpu(output_shape); auto y = allocate_gpu(output_shape);
auto workspace = allocate_gpu(workspace_shape); auto workspace = allocate_gpu(workspace_shape);
int algo_count = 1; int algo_count = 1;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
...@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx, ...@@ -133,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx,
MIGRAPHX_THROW("Workspace has changed during finalization."); MIGRAPHX_THROW("Workspace has changed during finalization.");
} }
shape miopen_quant_convolution::pack_int8_shape(shape& s) shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{ {
if(s.type() != shape::int8_type) if(s.type() != shape::int8_type)
{ {
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/device/pack.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <fstream> #include <fstream>
...@@ -54,43 +53,46 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -54,43 +53,46 @@ rb_type<T>* to_rocblas_type(T* x)
return reinterpret_cast<rb_type<T>*>(x); return reinterpret_cast<rb_type<T>*>(x);
} }
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end()); in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted(); check_shapes{in_shapes}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
return op.compute_shape(in_shapes); return op.compute_shape(in_shapes);
} }
argument miopen_quant_gemm::compute(context& ctx, void rocblas_quant_gemm::batch_not_transposed(const std::vector<std::size_t>& strides) const
const shape& output_shape, {
const std::vector<argument>& args) const if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("QUANT_DOT: batch size {" + to_string_range(strides) + "} is transposed!");
}
}
argument rocblas_quant_gemm::compute(context& ctx,
const shape& output_shape,
const std::vector<argument>& args) const
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
auto arg_num = args.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0]; rocblas_int ldc = args[2].get_shape().strides()[dim_0];
if(!transb)
{
device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
}
// need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API
if(transa)
{
device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
}
device::sync_stream(ctx.get_stream().get());
bool is_3inputs = (arg_num == 6); bool is_3inputs = (args.size() == 4);
int32_t beta = 0; int32_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
...@@ -124,18 +126,17 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -124,18 +126,17 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(args[arg_num - 2]) to_pointer(args.at(1)),
: to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
to_pointer(args[arg_num - 1]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
...@@ -155,11 +156,11 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -155,11 +156,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
k * n, k * n,
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
m * k, m * k,
...@@ -168,7 +169,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -168,7 +169,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
to_pointer(args[arg_num - 1]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
...@@ -182,7 +183,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -182,7 +183,7 @@ argument miopen_quant_gemm::compute(context& ctx,
} }
}); });
return args[arg_num - 1]; return is_3inputs ? args[3] : args[2];
} }
} // namespace gpu } // namespace gpu
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp> #include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/schedule_model.hpp> #include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/adjust_allocation.hpp> #include <migraphx/gpu/adjust_allocation.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/schedule.hpp> #include <migraphx/schedule.hpp>
...@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const ...@@ -62,6 +63,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
dead_code_elimination{}, dead_code_elimination{},
adjust_allocation{}, adjust_allocation{},
dead_code_elimination{}, dead_code_elimination{},
pack_int8_args{},
dead_code_elimination{},
fuse_ops{&ctx}, fuse_ops{&ctx},
dead_code_elimination{}, dead_code_elimination{},
write_literals{&ctx}, write_literals{&ctx},
......
...@@ -663,13 +663,13 @@ struct test_softmax2 : verify_program<test_softmax2> ...@@ -663,13 +663,13 @@ struct test_softmax2 : verify_program<test_softmax2>
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_softmax : verify_program<test_softmax<Axis>> struct test_softmax : verify_program<test_softmax<Axis, T>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}}; migraphx::shape s{T, {512, 4, 1067, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::softmax{Axis}, param); p.add_instruction(migraphx::op::softmax{Axis}, param);
...@@ -677,10 +677,38 @@ struct test_softmax : verify_program<test_softmax<Axis>> ...@@ -677,10 +677,38 @@ struct test_softmax : verify_program<test_softmax<Axis>>
} }
}; };
template struct test_softmax<0>; template struct test_softmax<0, migraphx::shape::float_type>;
template struct test_softmax<1>; template struct test_softmax<2, migraphx::shape::float_type>;
template struct test_softmax<2>; template struct test_softmax<1, migraphx::shape::double_type>;
template struct test_softmax<3>; template struct test_softmax<3, migraphx::shape::double_type>;
template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>;
// Verification program parameterised on an operator type T and an axis:
// adds a float_type parameter named "data" with dimensions {2, 3, 4, 1025}
// and applies the operator constructed as T{Axis} to it.
template <class T, int Axis>
struct test_arg_ops : verify_program<test_arg_ops<T, Axis>>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        const migraphx::shape data_shape{migraphx::shape::float_type, {2, 3, 4, 1025}};
        auto data = prog.add_parameter("data", data_shape);
        prog.add_instruction(T{Axis}, data);
        return prog;
    }
};
// Explicit instantiations of test_arg_ops: argmax and argmin, each for axes 0 through 3.
template struct test_arg_ops<migraphx::op::argmax, 0>;
template struct test_arg_ops<migraphx::op::argmax, 1>;
template struct test_arg_ops<migraphx::op::argmax, 2>;
template struct test_arg_ops<migraphx::op::argmax, 3>;
template struct test_arg_ops<migraphx::op::argmin, 0>;
template struct test_arg_ops<migraphx::op::argmin, 1>;
template struct test_arg_ops<migraphx::op::argmin, 2>;
template struct test_arg_ops<migraphx::op::argmin, 3>;
struct test_conv : verify_program<test_conv> struct test_conv : verify_program<test_conv>
{ {
...@@ -3601,32 +3629,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul ...@@ -3601,32 +3629,13 @@ struct test_lstm_bidirct_default_actv2 : verify_program<test_lstm_bidirct_defaul
} }
}; };
template <int Axis> template <int Axis, migraphx::shape::type_t T>
struct test_logsoftmax : verify_program<test_logsoftmax<Axis>> struct test_logsoftmax : verify_program<test_logsoftmax<Axis, T>>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template struct test_logsoftmax<0>;
template struct test_logsoftmax<1>;
template struct test_logsoftmax<2>;
template struct test_logsoftmax<3>;
template <int Axis>
struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}}; migraphx::shape s{T, {10, 4, 2080, 6}};
auto param = p.add_parameter("0", s); auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param); p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
...@@ -3634,7 +3643,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>> ...@@ -3634,7 +3643,16 @@ struct test_logsoftmax_1 : verify_program<test_logsoftmax_1<Axis>>
} }
}; };
template struct test_logsoftmax_1<0>; template struct test_logsoftmax<0, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::float_type>;
template struct test_logsoftmax<2, migraphx::shape::float_type>;
template struct test_logsoftmax<3, migraphx::shape::float_type>;
template struct test_logsoftmax<1, migraphx::shape::double_type>;
template struct test_logsoftmax<3, migraphx::shape::double_type>;
template struct test_logsoftmax<1, migraphx::shape::half_type>;
template struct test_logsoftmax<0, migraphx::shape::half_type>;
template struct test_logsoftmax<2, migraphx::shape::half_type>;
template struct test_logsoftmax<3, migraphx::shape::half_type>;
struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall> struct test_fp32_fp16_lall : verify_program<test_fp32_fp16_lall>
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment