Commit 8143e4fb authored by wsttiger

Merge branch 'master' into remove_concat

parents 0a4583b7 9ca0fbf1
 add_library(migraph_cpu
-    cpu_target.cpp
-    cpu_lowering.cpp
+    target.cpp
+    lowering.cpp
     gemm.cpp
 )
...
@@ -6,7 +6,7 @@
 namespace migraph {
 namespace cpu {
-struct cpu_lowering
+struct lowering
 {
     std::string name() const { return "cpu::lowering"; }
     void apply(program& p) const;
...
@@ -7,7 +7,7 @@
 namespace migraph {
 namespace cpu {
-struct cpu_target
+struct target
 {
     std::string name() const;
     std::vector<pass> get_passes(migraph::context& ctx) const;
...
-#include <migraph/cpu/cpu_lowering.hpp>
+#include <migraph/cpu/lowering.hpp>
 #include <migraph/instruction.hpp>
 #include <migraph/dfor.hpp>
 #include <migraph/operators.hpp>
@@ -312,8 +312,8 @@ struct cpu_concat
 struct cpu_gemm
 {
-    op::gemm op;
-    std::string name() const { return "cpu::gemm"; }
+    op::dot op;
+    std::string name() const { return "cpu::dot"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
@@ -437,7 +437,7 @@ struct relu_op
     std::string name() const { return "cpu::relu"; }
     auto fcn() const
     {
-        return [](auto x) { return x > 0 ? x : 0; };
+        return [](auto x) { return std::max(decltype(x){0}, x); };
     }
 };
@@ -592,7 +592,7 @@ struct cpu_apply
 {
     apply_map["im2col"] = extend_op<cpu_im2col, op::im2col>();
     apply_map["convolution"] = extend_op<cpu_convolution, op::convolution>();
-    apply_map["gemm"] = extend_op<cpu_gemm, op::gemm>();
+    apply_map["dot"] = extend_op<cpu_gemm, op::dot>();
     apply_map["batch_norm_inference"] =
         extend_op<cpu_batch_norm_inference, op::batch_norm_inference>();
     apply_map["contiguous"] = extend_op<cpu_contiguous, op::contiguous>();
@@ -664,7 +664,7 @@ struct cpu_apply
     }
 };
-void cpu_lowering::apply(program& p) const { cpu_apply{&p}.apply(); }
+void lowering::apply(program& p) const { cpu_apply{&p}.apply(); }
 } // namespace cpu
...
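Note on the relu hunk above: the old lambda's ternary mixes x with the int literal 0, which forces promotion to a common type; std::max(decltype(x){0}, x) builds the zero in x's own type. A minimal standalone sketch of the fixed form, using plain floats rather than the migraph build:

#include <algorithm>
#include <cassert>

template <class T>
T relu(T x)
{
    // decltype(x){0} constructs zero as T, so std::max compares two values
    // of one type and no int conversion is involved
    return std::max(decltype(x){0}, x);
}

int main()
{
    assert(relu(-1.5f) == 0.0f);
    assert(relu(2.0) == 2.0);
}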
-#include <migraph/cpu/cpu_target.hpp>
-#include <migraph/cpu/cpu_lowering.hpp>
+#include <migraph/cpu/target.hpp>
+#include <migraph/cpu/lowering.hpp>
 #include <migraph/auto_contiguous.hpp>
 namespace migraph {
 namespace cpu {
-std::string cpu_target::name() const { return "cpu"; }
-std::vector<pass> cpu_target::get_passes(migraph::context&) const
+std::string target::name() const { return "cpu"; }
+std::vector<pass> target::get_passes(migraph::context&) const
 {
-    return {auto_contiguous{}, cpu_lowering{}};
+    return {auto_contiguous{}, lowering{}};
 }
 } // namespace cpu
...
@@ -10,10 +10,11 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()
 add_library(migraph_device
     device/add.cpp
     device/add_relu.cpp
     device/contiguous.cpp
+    device/mul.cpp
     device/concat.cpp
 )
 rocm_clang_tidy_check(migraph_device)
@@ -36,6 +37,7 @@ add_library(migraph_gpu
     relu.cpp
     leaky_relu.cpp
     add.cpp
+    mul.cpp
     batchnorm.cpp
     write_literals.cpp
     rocblas.cpp
...
@@ -14,9 +14,9 @@ shape hip_add::compute_shape(const std::vector<shape>& inputs) const
     return inputs.at(0);
 }
-argument hip_add::compute(context&, const shape&, const std::vector<argument>& args) const
+argument hip_add::compute(context& ctx, const shape&, const std::vector<argument>& args) const
 {
-    device::add(args[2], args[0], args[1]);
+    device::add(ctx.get_stream().get(), args[2], args[0], args[1]);
     return args[2];
 }
@@ -34,7 +34,7 @@ argument miopen_add::compute(context& ctx,
     auto a_desc = make_tensor(args[0].get_shape());
     auto b_desc = make_tensor(args[1].get_shape());
     auto c_desc = make_tensor(output_shape);
-    miopenOpTensor(ctx.handle.get(),
+    miopenOpTensor(ctx.get_stream().get_miopen(),
                    miopenTensorOpAdd,
                    &alpha,
                    a_desc.get(),
...
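From this hunk on, ops stop reaching into ctx.handle and pull their queue from the context instead: raw HIP work uses ctx.get_stream().get() and MIOpen calls use ctx.get_stream().get_miopen(). A rough sketch of the wrapper those two accessors imply — the field names are assumptions, not the real migraph type:

#include <hip/hip_runtime.h>
#include <miopen/miopen.h>

// Hedged sketch only: one stream object carries the raw HIP queue and a
// MIOpen handle bound to that queue, so kernels and MIOpen ops serialize
// on the same stream. Member names here are hypothetical.
struct hip_stream
{
    hipStream_t raw    = nullptr;
    miopenHandle_t mio = nullptr;

    hipStream_t get() const { return raw; }
    miopenHandle_t get_miopen() const { return mio; }
};

struct context
{
    hip_stream stream;
    hip_stream& get_stream() { return stream; }
};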
@@ -23,7 +23,7 @@ argument miopen_batch_norm_inference::compute(context& ctx,
     float alpha = 1.0, beta = 0.0f;
-    miopenBatchNormalizationForwardInference(ctx.handle.get(),
+    miopenBatchNormalizationForwardInference(ctx.get_stream().get_miopen(),
                                              miopenBatchNormMode_t(op.bn_mode),
                                              &alpha,
                                              &beta,
...
@@ -14,11 +14,12 @@ shape hip_concat::compute_shape(std::vector<shape> inputs) const
     return op.compute_shape(inputs);
 }
-argument
-hip_concat::compute(context&, const shape& output_shape, const std::vector<argument>& args) const
+argument hip_concat::compute(context& ctx,
+                             const shape& output_shape,
+                             const std::vector<argument>& args) const
 {
     std::vector<std::size_t> offsets = op.compute_offsets(output_shape, args);
-    return device::concat(output_shape, args, offsets);
+    return device::concat(ctx.get_stream().get(), output_shape, args, offsets);
 }
 } // namespace gpu
...
@@ -12,13 +12,14 @@ shape miopen_contiguous::compute_shape(const std::vector<shape>& inputs) const
     check_shapes{inputs, *this}.has(2);
     return op.compute_shape({inputs.at(0)});
 }
-argument
-miopen_contiguous::compute(context&, shape output_shape, const std::vector<argument>& args) const
+argument miopen_contiguous::compute(context& ctx,
+                                    shape output_shape,
+                                    const std::vector<argument>& args) const
 {
     assert(output_shape == args[1].get_shape());
     assert(output_shape.standard());
     (void)output_shape;
-    device::contiguous(args.at(1), args.at(0));
+    device::contiguous(ctx.get_stream().get(), args.at(1), args.at(0));
     return args.at(1);
 }
...
@@ -21,7 +21,7 @@ argument miopen_convolution::compute(context& ctx,
     auto y_desc = make_tensor(output_shape);
     float alpha = 1, beta = 0;
-    miopenConvolutionForward(ctx.handle.get(),
+    miopenConvolutionForward(ctx.get_stream().get_miopen(),
                              &alpha,
                              x_desc.get(),
                              args[0].implicit(),
@@ -47,18 +47,22 @@ shape miopen_convolution::compile(context& ctx,
     auto y_desc = make_tensor(output_shape);
     std::size_t workspace_size = 0;
-    miopenConvolutionForwardGetWorkSpaceSize(
-        ctx.handle.get(), w_desc.get(), x_desc.get(), cd.get(), y_desc.get(), &workspace_size);
+    miopenConvolutionForwardGetWorkSpaceSize(ctx.get_stream().get_miopen(),
+                                             w_desc.get(),
+                                             x_desc.get(),
+                                             cd.get(),
+                                             y_desc.get(),
+                                             &workspace_size);
     workspace_shape = shape{shape::int8_type, {workspace_size}};
     auto x = to_gpu(generate_argument(inputs[0]->get_shape()));
     auto w = to_gpu(generate_argument(inputs[1]->get_shape()));
-    auto y = to_gpu(generate_argument(output_shape));
+    auto y = allocate_gpu(output_shape);
     auto workspace = allocate_gpu(workspace_shape);
     int algo_count = 1;
     miopenConvAlgoPerf_t perf;
-    miopenFindConvolutionForwardAlgorithm(ctx.handle.get(),
+    miopenFindConvolutionForwardAlgorithm(ctx.get_stream().get_miopen(),
                                           x_desc.get(),
                                           x.implicit(),
                                           w_desc.get(),
...
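Note on the probe buffers in the compile hunk above: miopenFindConvolutionForwardAlgorithm only writes to y, so the new code allocates uninitialized device memory with allocate_gpu instead of generating host data and copying it up via to_gpu(generate_argument(...)); x and w keep generated values because the algorithm search reads them.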
@@ -5,14 +5,18 @@ namespace migraph {
 namespace gpu {
 namespace device {
-void add(const argument& result, const argument& arg1, const argument& arg2)
+void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
-    nary(result, arg1, arg2)([](auto x, auto y) { return x + y; });
+    nary(stream, result, arg1, arg2)([](auto x, auto y) { return x + y; });
 }
-void add(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
+void add(hipStream_t stream,
+         const argument& result,
+         const argument& arg1,
+         const argument& arg2,
+         const argument& arg3)
 {
-    nary(result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
+    nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x + y + z; });
 }
 } // namespace device
...
@@ -5,17 +5,22 @@ namespace migraph {
 namespace gpu {
 namespace device {
-void add_relu(const argument& result, const argument& arg1, const argument& arg2)
+void add_relu(hipStream_t stream,
+              const argument& result,
+              const argument& arg1,
+              const argument& arg2)
 {
-    nary(result, arg1, arg2)([](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
+    nary(stream, result, arg1, arg2)(
+        [](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
 }
-void add_relu(const argument& result,
+void add_relu(hipStream_t stream,
+              const argument& result,
               const argument& arg1,
               const argument& arg2,
               const argument& arg3)
 {
-    nary(result, arg1, arg2, arg3)(
+    nary(stream, result, arg1, arg2, arg3)(
         [](auto x, auto y, auto z) { return std::max<decltype(x + y + z)>(0, x + y + z); });
 }
...
@@ -8,11 +8,11 @@ namespace migraph {
 namespace gpu {
 namespace device {
-argument concat(const migraph::shape& output_shape,
+argument concat(hipStream_t stream,
+                const migraph::shape& output_shape,
                 std::vector<migraph::argument> args,
                 std::vector<std::size_t> offsets)
 {
-    // migraph::argument& result = args.back();
     for(std::size_t l = 0; l < args.size() - 1; l++)
     {
         auto argl = args[l];
@@ -23,12 +23,11 @@ argument concat(const migraph::shape& output_shape,
             const auto* inptr = input.data();
             hip_tensor_descriptor<ndim> desc_input(input.get_shape());
             hip_tensor_descriptor<ndim> desc_output(output.get_shape());
-            gs_launch(nelements)(
+            gs_launch(stream, nelements)(
                 [=](auto i) { outptr[desc_output.linear(desc_input.multi(i))] = inptr[i]; });
             });
         });
     }
-    // return result;
     return args.back();
 }
...
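The concat kernel above gives each input a precomputed element offset into the flat output; for concatenation along the leading axis of standard-layout tensors the mapping degenerates to offset plus linear index, which this host-side sketch emulates (the real kernel maps coordinates through hip_tensor_descriptor so any axis works):

#include <cstddef>
#include <iostream>
#include <vector>

// Host emulation of the offset scheme, leading-axis case only.
std::vector<float> concat0(const std::vector<std::vector<float>>& args,
                           const std::vector<std::size_t>& offsets,
                           std::size_t out_elements)
{
    std::vector<float> out(out_elements);
    for(std::size_t l = 0; l < args.size(); l++)
        for(std::size_t i = 0; i < args[l].size(); i++) // one GPU thread per i
            out[offsets[l] + i] = args[l][i];
    return out;
}

int main()
{
    // offsets as compute_offsets would produce them: each input starts
    // where the previous one ends
    auto r = concat0({{1, 2}, {3, 4, 5}}, {0, 2}, 5);
    for(auto x : r)
        std::cout << x << ' '; // prints: 1 2 3 4 5
}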
@@ -6,9 +6,9 @@ namespace migraph {
 namespace gpu {
 namespace device {
-void contiguous(argument result, argument arg)
+void contiguous(hipStream_t stream, argument result, argument arg)
 {
-    nary_nonstandard(std::move(result), std::move(arg))([](auto x) { return x; });
+    nary_nonstandard(stream, std::move(result), std::move(arg))([](auto x) { return x; });
 }
 } // namespace device
...
@@ -21,7 +21,7 @@ __global__ void launcher(F f)
     f(idx);
 }
-inline auto launch(std::size_t global, std::size_t local)
+inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
 {
     return [=](auto f) {
         assert(local > 0);
@@ -29,17 +29,17 @@ inline auto launch(std::size_t global, std::size_t local)
         using f_type = decltype(f);
         dim3 nblocks(global / local);
         dim3 nthreads(local);
-        hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, nullptr, f);
+        hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
     };
 }
-inline auto gs_launch(std::size_t n, std::size_t local = 1024)
+inline auto gs_launch(hipStream_t stream, std::size_t n, std::size_t local = 1024)
 {
     std::size_t groups = 1 + n / local;
     std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
     return [=](auto f) {
-        launch(nglobal, local)([=](auto idx) {
+        launch(stream, nglobal, local)([=](auto idx) {
             for(size_t i = idx.global; i < n; i += nglobal)
             {
                 f(i);
...
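The stream parameter threads through both launchers; the index scheme of gs_launch is unchanged: the grid is capped at 256 groups and each thread strides over the work with step nglobal. A CPU emulation of that mapping (sizes are arbitrary) showing every index is visited exactly once whatever n is:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    std::size_t n = 5000, local = 1024;
    std::size_t groups  = 1 + n / local;
    std::size_t nglobal = std::min<std::size_t>(256, groups) * local;

    std::vector<int> hits(n, 0);
    for(std::size_t g = 0; g < nglobal; g++)        // g plays one GPU thread
        for(std::size_t i = g; i < n; i += nglobal) // the grid-stride loop
            hits[i]++;

    assert(std::all_of(hits.begin(), hits.end(), [](int h) { return h == 1; }));
}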
@@ -3,6 +3,7 @@
 #include <migraph/gpu/device/tensor.hpp>
 #include <migraph/gpu/device/launch.hpp>
+#include <migraph/gpu/device/types.hpp>
 #include <migraph/functional.hpp>
 #include <migraph/ranges.hpp>
@@ -32,16 +33,16 @@ auto pack_vec4(Ts... xs)
 }
 template <class F, class... Arguments>
-auto nary_nonstandard_impl(F f, argument result, Arguments... args)
+auto nary_nonstandard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
         visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-            auto data = pack(
-                std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
+            auto data = pack(std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()},
+                                            device_cast(inputs.data()))...);
             hip_tensor_descriptor<ndim> out_desc(output_shape);
-            auto* outp = output.data();
-            gs_launch(output_shape.elements())([=](auto i) {
+            auto* outp = device_cast(output.data());
+            gs_launch(stream, output_shape.elements())([=](auto i) {
                 data([&](auto&&... ps) {
                     auto outidx = out_desc.multi(i);
                     outp[i] = f(ps.second[ps.first.linear(outidx)]...);
@@ -52,8 +53,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
 }
 template <class F>
-void trinary_broadcast_vec_impl(
-    F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
+void trinary_broadcast_vec_impl(hipStream_t stream,
+                                F f,
+                                const argument& result,
+                                const argument& arg1,
+                                const argument& arg2,
+                                const argument& arg3)
 {
     const auto& output_shape = result.get_shape();
     const auto& b_shape = arg3.get_shape();
@@ -67,11 +72,11 @@ void trinary_broadcast_vec_impl(
     auto bdim_next_stride = bdim_stride * bdim_len;
     visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = std::remove_cv_t<typename decltype(output)::value_type>;
-        auto* xp = as_vec4(input1.data());
-        auto* yp = as_vec4(input2.data());
-        auto* zp = as_vec4(input3.data());
-        auto* outp = as_vec4(output.data());
+        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+        auto* xp = as_vec4(device_cast(input1.data()));
+        auto* yp = as_vec4(device_cast(input2.data()));
+        auto* zp = as_vec4(device_cast(input3.data()));
+        auto* outp = as_vec4(device_cast(output.data()));
         const std::size_t vec_size = 4;
         const std::size_t nlocal = 1024;
@@ -79,7 +84,7 @@ void trinary_broadcast_vec_impl(
         const std::size_t n = output.size() / vec_size;
         const std::size_t bdim_vec_len = bdim_len / vec_size;
-        launch(nglobal, nlocal)([=](auto idx) __device__ {
+        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
@@ -107,8 +112,12 @@ void trinary_broadcast_vec_impl(
 }
 template <class F>
-void trinary_broadcast_impl(
-    F f, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
+void trinary_broadcast_impl(hipStream_t stream,
+                            F f,
+                            const argument& result,
+                            const argument& arg1,
+                            const argument& arg2,
+                            const argument& arg3)
 {
     const auto& output_shape = result.get_shape();
     const auto& b_shape = arg3.get_shape();
@@ -122,17 +131,17 @@ void trinary_broadcast_impl(
     auto bdim_next_stride = bdim_stride * bdim_len;
     visit_all(result, arg1, arg2, arg3)([&](auto output, auto input1, auto input2, auto input3) {
-        using type = std::remove_cv_t<typename decltype(output)::value_type>;
-        auto* xp = input1.data();
-        auto* yp = input2.data();
-        auto* zp = input3.data();
-        auto* outp = output.data();
+        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+        auto* xp = device_cast(input1.data());
+        auto* yp = device_cast(input2.data());
+        auto* zp = device_cast(input3.data());
+        auto* outp = device_cast(output.data());
         const std::size_t nlocal = 1024;
         const std::size_t nglobal = 256 * nlocal;
         const std::size_t n = output.size();
-        launch(nglobal, nlocal)([=](auto idx) __device__ {
+        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPH_DEVICE_SHARED type buffer[2048];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_len; i += nlocal)
@@ -154,10 +163,8 @@ void trinary_broadcast_impl(
 }
 template <class F>
-void binary_broadcast_vec_impl(F f,
-                               const argument& result,
-                               const argument& arg1,
-                               const argument& arg2)
+void binary_broadcast_vec_impl(
+    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
 {
     const auto& output_shape = result.get_shape();
     const auto& b_shape = arg2.get_shape();
@@ -171,10 +178,10 @@ void binary_broadcast_vec_impl(F f,
     auto bdim_next_stride = bdim_stride * bdim_len;
     visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = std::remove_cv_t<typename decltype(output)::value_type>;
-        auto* xp = as_vec4(input1.data());
-        auto* yp = as_vec4(input2.data());
-        auto* outp = as_vec4(output.data());
+        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+        auto* xp = as_vec4(device_cast(input1.data()));
+        auto* yp = as_vec4(device_cast(input2.data()));
+        auto* outp = as_vec4(device_cast(output.data()));
         const std::size_t vec_size = 4;
         const std::size_t nlocal = 1024;
@@ -182,7 +189,7 @@ void binary_broadcast_vec_impl(F f,
         const std::size_t n = output.size() / vec_size;
         const std::size_t bdim_vec_len = bdim_len / vec_size;
-        launch(nglobal, nlocal)([=](auto idx) __device__ {
+        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
@@ -209,7 +216,8 @@ void binary_broadcast_vec_impl(F f,
 }
 template <class F>
-void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
+void binary_broadcast_impl(
+    hipStream_t stream, F f, const argument& result, const argument& arg1, const argument& arg2)
 {
     const auto& output_shape = result.get_shape();
     const auto& b_shape = arg2.get_shape();
@@ -223,16 +231,16 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
     auto bdim_next_stride = bdim_stride * bdim_len;
     visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
-        using type = std::remove_cv_t<typename decltype(output)::value_type>;
-        auto* xp = input1.data();
-        auto* yp = input2.data();
-        auto* outp = output.data();
+        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
+        auto* xp = device_cast(input1.data());
+        auto* yp = device_cast(input2.data());
+        auto* outp = device_cast(output.data());
         const std::size_t nlocal = 1024;
         const std::size_t nglobal = 256 * nlocal;
         const std::size_t n = output.size();
-        launch(nglobal, nlocal)([=](auto idx) __device__ {
+        launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
             MIGRAPH_DEVICE_SHARED type buffer[2048];
             // Load bias into LDS
             for(size_t i = idx.local; i < bdim_len; i += nlocal)
@@ -253,16 +261,16 @@ void binary_broadcast_impl(F f, const argument& result, const argument& arg1, co
 }
 template <class F, class... Arguments>
-void nary_standard_vec_impl(F f, argument result, Arguments... args)
+void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
     // assert(x.get_shape().elements() == y.get_shape().elements());
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
-        using type = std::remove_cv_t<typename decltype(output)::value_type>;
+        using type = device_type<std::remove_cv_t<typename decltype(output)::value_type>>;
         const std::size_t vec_size = 4;
-        auto data = pack_vec4(inputs.data()...);
-        auto* outp = as_vec4(output.data());
-        gs_launch(output_shape.elements() / vec_size)([=](auto i) {
+        auto data = pack_vec4(device_cast(inputs.data())...);
+        auto* outp = as_vec4(device_cast(output.data()));
+        gs_launch(stream, output_shape.elements() / vec_size)([=](auto i) {
            vec4<type> out = outp[i];
            data(
                [&](auto... xs) {
@@ -278,54 +286,56 @@ void nary_standard_vec_impl(F f, argument result, Arguments... args)
 }
 template <class F, class... Arguments>
-void nary_standard_impl(F f, argument result, Arguments... args)
+void nary_standard_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
     // assert(x.get_shape().elements() == y.get_shape().elements());
     const auto& output_shape = result.get_shape();
     visit_all(result, args...)([&](auto output, auto... inputs) {
-        auto data = pack(inputs.data()...);
-        auto* outp = output.data();
-        gs_launch(output_shape.elements())(
+        auto data = pack(device_cast(inputs.data())...);
+        auto* outp = device_cast(output.data());
+        gs_launch(stream, output_shape.elements())(
             [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
     });
 }
 template <class F, class... Arguments>
-void nary_impl(F f, argument result, Arguments... args)
+void nary_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
     bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
     bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
     bool same_shapes =
         all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
     if(standard or (packed and same_shapes))
-        nary_standard_impl(f, result, args...);
+        nary_standard_impl(stream, f, result, args...);
     else
-        nary_nonstandard_impl(f, result, args...);
+        nary_nonstandard_impl(stream, f, result, args...);
 }
 template <class... Arguments>
-auto nary_nonstandard(argument result, Arguments... args)
+auto nary_nonstandard(hipStream_t stream, argument result, Arguments... args)
 {
-    return [=](auto f) { nary_nonstandard_impl(f, result, args...); };
+    return [=](auto f) { nary_nonstandard_impl(stream, f, result, args...); };
 }
 template <class... Arguments>
-auto nary_standard(argument result, Arguments... args)
+auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 {
-    return [=](auto f) { nary_standard_impl(f, result, args...); };
+    return [=](auto f) { nary_standard_impl(stream, f, result, args...); };
 }
 template <class... Arguments>
-auto nary(argument result, Arguments... args)
+auto nary(hipStream_t stream, argument result, Arguments... args)
 {
-    return [=](auto f) { nary_impl(f, result, args...); };
+    return [=](auto f) { nary_impl(stream, f, result, args...); };
 }
-inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
+inline auto
+nary(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
     return [=](auto f) {
         // TODO: Check result and arg1 shape is the same
-        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted())
+        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted() and
+           not arg2.get_shape().scalar())
         {
             auto not_zero = [](auto x) { return x != 0; };
             const auto& strides = arg2.get_shape().strides();
@@ -339,18 +349,21 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
             const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
                                         (arg1.get_shape().elements() % 4 == 0);
             if(divisible_by_4)
-                binary_broadcast_vec_impl(f, result, arg1, arg2);
+                binary_broadcast_vec_impl(stream, f, result, arg1, arg2);
             else
-                binary_broadcast_impl(f, result, arg1, arg2);
+                binary_broadcast_impl(stream, f, result, arg1, arg2);
             return;
             }
         }
-        nary_impl(f, result, arg1, arg2);
+        nary_impl(stream, f, result, arg1, arg2);
     };
 }
-inline auto
-nary(const argument& result, const argument& arg1, const argument& arg2, const argument& arg3)
+inline auto nary(hipStream_t stream,
+                 const argument& result,
+                 const argument& arg1,
+                 const argument& arg2,
+                 const argument& arg3)
 {
     return [=](auto f) {
         // TODO: Check result and arg1 shape is the same
@@ -369,13 +382,13 @@ nary(const argument& result, const argument& arg1, const argument& arg2, const a
             const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
                                         (arg1.get_shape().elements() % 4 == 0);
             if(divisible_by_4)
-                trinary_broadcast_vec_impl(f, result, arg1, arg2, arg3);
+                trinary_broadcast_vec_impl(stream, f, result, arg1, arg2, arg3);
             else
-                trinary_broadcast_impl(f, result, arg1, arg2, arg3);
+                trinary_broadcast_impl(stream, f, result, arg1, arg2, arg3);
             return;
             }
        }
-        nary_impl(f, result, arg1, arg2, arg3);
+        nary_impl(stream, f, result, arg1, arg2, arg3);
     };
 }
...
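For callers the only visible change in this header is the leading hipStream_t; the functor-last style stays, as in device/add.cpp above. A hedged sketch of a custom elementwise op built on it (axpb is a made-up name, not in the tree):

#include <migraph/gpu/device/nary.hpp>

// Hypothetical helper for illustration: 2*x + b elementwise, reusing
// whichever kernel nary dispatches to (standard, vec4 broadcast, generic).
void axpb(hipStream_t stream,
          const migraph::argument& result,
          const migraph::argument& x,
          const migraph::argument& b)
{
    migraph::gpu::device::nary(stream, result, x, b)(
        [](auto xi, auto bi) { return 2 * xi + bi; });
}

The new not arg2.get_shape().scalar() guard also matters here: a scalar broadcast has all-zero strides, leaving the fast path with no broadcast axis to key on, so such cases now fall through to nary_impl.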
/*=============================================================================
    Copyright (c) 2017 Paul Fultz II
    types.hpp
    Distributed under the Boost Software License, Version 1.0. (See accompanying
    file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
==============================================================================*/
#ifndef MIGRAPH_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP
#define MIGRAPH_GUARD_RTGLIB_GPU_DEVICE_TYPES_HPP

#include <migraph/half.hpp>

namespace migraph {
namespace gpu {
namespace device {

using gpu_half = __fp16;

namespace detail {
template <class T>
struct device_type
{
    using type = T;
};

template <>
struct device_type<half>
{
    using type = gpu_half;
};

template <class T>
struct host_type
{
    using type = T;
};

template <>
struct host_type<gpu_half>
{
    using type = half;
};
} // namespace detail

template <class T>
using host_type = typename detail::host_type<T>::type;

template <class T>
using device_type = typename detail::device_type<T>::type;

template <class T>
host_type<T> host_cast(T x)
{
    return reinterpret_cast<host_type<T>>(x);
}

template <class T>
host_type<T>* host_cast(T* x)
{
    return reinterpret_cast<host_type<T>*>(x);
}

template <class T>
device_type<T> device_cast(T x)
{
    return reinterpret_cast<device_type<T>>(x);
}

template <class T>
device_type<T>* device_cast(T* x)
{
    return reinterpret_cast<device_type<T>*>(x);
}

} // namespace device
} // namespace gpu
} // namespace migraph
#endif
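A quick illustration of how the new types.hpp is meant to be used — a sketch assuming only what the header above defines:

#include <migraph/gpu/device/types.hpp>
#include <type_traits>

using migraph::gpu::device::device_type;
using migraph::gpu::device::gpu_half;

// Only half is rewritten; every other type maps to itself.
static_assert(std::is_same<device_type<float>, float>{}, "float passes through");
static_assert(std::is_same<device_type<migraph::half>, gpu_half>{}, "half becomes __fp16");

// Typical use, as in nary.hpp: reinterpret a visited host pointer so the
// kernel lambda captures a pointer type that is valid in device code.
void example(migraph::half* host_ptr)
{
    gpu_half* p = migraph::gpu::device::device_cast(host_ptr);
    (void)p;
}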
#include <migraph/gpu/device/mul.hpp>
#include <migraph/gpu/device/nary.hpp>

namespace migraph {
namespace gpu {
namespace device {

void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
{
    nary(stream, result, arg1, arg2)([](auto x, auto y) { return x * y; });
}

void mul(hipStream_t stream,
         const argument& result,
         const argument& arg1,
         const argument& arg2,
         const argument& arg3)
{
    nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z) { return x * y * z; });
}

} // namespace device
} // namespace gpu
} // namespace migraph
@@ -82,13 +82,14 @@ struct fusion
     // int algo_count = 1;
     // miopenConvFwdAlgorithm_t algo;
     // miopenFusionPlanConvolutionGetAlgo(fp.get(), 1, &algo_count, &algo);
-    // miopenFusionPlanGetWorkSpaceSize(ctx.handle.get(), fp.get(), &ws_size, algo);
+    // miopenFusionPlanGetWorkSpaceSize(ctx.get_stream().get_miopen(), fp.get(), &ws_size,
+    // algo);
     return shape{shape::int8_type, {ws_size}};
 }
 void compile(context& ctx)
 {
-    auto status = miopenCompileFusionPlan(ctx.handle.get(), fp.get());
+    auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
     if(status != miopenStatusSuccess)
         MIGRAPH_THROW("Compiling fusion plan failed");
 }
@@ -100,7 +101,7 @@ struct fusion
 {
     auto x_td = make_tensor(x.get_shape());
     auto y_td = make_tensor(y.get_shape());
-    auto status = miopenExecuteFusionPlan(ctx.handle.get(),
+    auto status = miopenExecuteFusionPlan(ctx.get_stream().get_miopen(),
                                           fp.get(),
                                           x_td.get(),
                                           x.implicit(),
@@ -133,15 +134,12 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
         return false;
     auto wei = ins->inputs().at(1)->get_shape();
     assert(wei.lens().size() == 4);
-    auto channels = wei.lens()[1] * wei.lens()[0];
-    if(wei.lens()[0] > 64 and channels > 32768)
-        return false;
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
-    if(conv.algo == miopenConvolutionFwdAlgoWinograd)
+    if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
         return false;
     auto op = conv.op;
-    return op.padding == make_array<size_t>(0, 0) and op.stride == make_array<size_t>(1, 1) and
-           op.dilation == make_array<size_t>(1, 1);
+    return contains({{0, 0}, {1, 1}, {2, 2}}, op.padding) and
+           contains({{0, 0}, {1, 1}}, op.stride) and op.dilation == make_array<size_t>(1, 1);
 }
 struct hip_triadd
@@ -152,9 +150,9 @@ struct hip_triadd
         check_shapes{inputs, *this}.has(4);
         return inputs.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>& args) const
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
-        device::add(args.at(3), args.at(0), args.at(1), args.at(2));
+        device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
         return args.at(3);
     }
 };
@@ -167,9 +165,9 @@ struct hip_triadd_relu
         check_shapes{inputs, *this}.has(4);
         return inputs.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>& args) const
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
-        device::add_relu(args.at(3), args.at(0), args.at(1), args.at(2));
+        device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
         return args.at(3);
     }
 };
@@ -182,9 +180,9 @@ struct hip_add_relu
         check_shapes{inputs, *this}.has(3);
         return inputs.front();
     }
-    argument compute(context&, const shape&, const std::vector<argument>& args) const
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
     {
-        device::add_relu(args.at(2), args.at(0), args.at(1));
+        device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
         return args.at(2);
     }
 };
...