Commit 4c031df7 authored by wsttiger's avatar wsttiger
Browse files

Fixed conflicts

parents d32653a5 ed5f9897
......@@ -9,6 +9,7 @@
#include <migraph/pass_config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
void eliminate_workspace::apply(program& p) const
......@@ -38,5 +39,7 @@ void eliminate_workspace::apply(program& p) const
p.remove_instruction(a);
}
}
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
......@@ -6,7 +6,7 @@
#include <migraph/instruction.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct fusion
......@@ -132,6 +132,8 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if(ins->name() != "gpu::convolution")
return false;
if(ins->get_shape().type() != shape::float_type)
return false;
auto wei = ins->inputs().at(1)->get_shape();
assert(wei.lens().size() == 4);
auto conv = any_cast<miopen_convolution>(ins->get_operator());
......@@ -155,6 +157,7 @@ struct hip_triadd
device::add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct hip_triadd_relu
......@@ -170,6 +173,7 @@ struct hip_triadd_relu
device::add_relu(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
return args.at(3);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct hip_add_relu
......@@ -185,6 +189,7 @@ struct hip_add_relu
device::add_relu(ctx.get_stream().get(), args.at(2), args.at(0), args.at(1));
return args.at(2);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct find_add_relu
......@@ -271,6 +276,7 @@ struct miopen_conv_bias
f.compile(ctx);
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
struct miopen_conv_bias_relu
......@@ -314,6 +320,7 @@ struct miopen_conv_bias_relu
f.compile(ctx);
return f.get_workspace(ctx);
}
int output_alias(const std::vector<shape>& shapes) const { return shapes.size() - 1; }
};
template <class... Ms>
......@@ -382,5 +389,5 @@ void fuse_ops::apply(program& p) const
}
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
......@@ -5,8 +5,68 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
template <class... Ts>
void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
rocblas_sgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
rocblas_dgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
rocblas_hgemm(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
void generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPH_THROW("Type unsupported by rocblas");
}
template <class T>
struct compute_rocblas_type
{
using type = T;
};
template <class T>
struct compute_rocblas_type<const T>
{
using type = const typename compute_rocblas_type<T>::type;
};
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
template <class T>
using rb_type = typename compute_rocblas_type<T>::type;
template <class T>
rb_type<T> to_rocblas_type(T x)
{
return reinterpret_cast<const rb_type<T>&>(x);
}
template <class T>
rb_type<T>* to_rocblas_type(T* x)
{
return reinterpret_cast<rb_type<T>*>(x);
}
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(3);
......@@ -26,23 +86,30 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int m = output_shape.lens()[0];
rocblas_int n = output_shape.lens()[1];
rocblas_int k = args[0].get_shape().lens()[1];
rocblas_sgemm(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha,
args[1].implicit(),
ldb,
args[0].implicit(),
lda,
&beta,
args[2].implicit(),
ldc);
output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(alpha));
auto beta_r = to_rocblas_type(as(beta));
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
generic_rocblas_gemm(as,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args[1]),
ldb,
to_pointer(args[0]),
lda,
&beta_r,
to_pointer(args[2]),
ldc);
});
return args[2];
}
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
......@@ -7,6 +7,7 @@
#include <vector>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
using hip_ptr = MIGRAPH_MANAGE_PTR(void, hipFree);
......@@ -107,5 +108,7 @@ void copy_to_gpu(argument src, argument dst)
if(status != hipSuccess)
MIGRAPH_THROW("Copy to gpu failed: " + hip_error(status));
}
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
......@@ -15,9 +15,11 @@
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/config.hpp>
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct hip_add
......@@ -38,7 +40,7 @@ struct miopen_add
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -15,9 +15,11 @@
#include <migraph/iterator_for.hpp>
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/context.hpp>
#include <migraph/config.hpp>
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct miopen_batch_norm_inference
......@@ -31,7 +33,7 @@ struct miopen_batch_norm_inference
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -7,6 +7,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
......@@ -18,6 +19,7 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct hip_concat
......@@ -32,7 +34,7 @@ struct hip_concat
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -5,8 +5,10 @@
#include <migraph/gpu/rocblas.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/env.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
MIGRAPH_DECLARE_ENV_VAR(MIGRAPH_DISABLE_NULL_STREAM)
......@@ -114,6 +116,7 @@ struct context
std::shared_ptr<hip_device> current_device;
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -7,6 +7,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
......@@ -18,6 +19,7 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct miopen_contiguous
......@@ -30,7 +32,7 @@ struct miopen_contiguous
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -7,6 +7,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
......@@ -18,6 +19,7 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct miopen_convolution
......@@ -42,7 +44,7 @@ struct miopen_convolution
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,9 +3,11 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
namespace device {
......@@ -19,6 +21,7 @@ void add(hipStream_t stream,
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,9 +3,11 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_ADD_RELU_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
namespace device {
......@@ -22,6 +24,7 @@ void add_relu(hipStream_t stream,
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,9 +2,11 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_CONCAT_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
namespace device {
......@@ -15,6 +17,7 @@ argument concat(hipStream_t stream,
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,9 +2,11 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_KERNELS_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
namespace device {
......@@ -12,6 +14,7 @@ void contiguous(hipStream_t stream, argument result, argument arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,9 +3,11 @@
#define MIGRAPH_GUARD_RTGLIB_DEVICE_MUL_HPP
#include <migraph/argument.hpp>
#include <migraph/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
namespace device {
......@@ -19,6 +21,7 @@ void mul(hipStream_t stream,
} // namespace device
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -3,8 +3,10 @@
#include <string>
#include <migraph/instruction_ref.hpp>
#include <migraph/config.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
struct program;
namespace gpu {
......@@ -15,6 +17,7 @@ struct eliminate_workspace
void apply(program& p) const;
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,9 +2,11 @@
#define MIGRAPH_GUARD_RTGLIB_FUSE_OPS_HPP
#include <migraph/program.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/context.hpp>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
......@@ -16,7 +18,7 @@ struct fuse_ops
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -10,6 +10,7 @@
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/device/contiguous.hpp>
#include <migraph/gpu/device/add.hpp>
#include <migraph/iterator_for.hpp>
......@@ -18,6 +19,7 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct miopen_gemm
......@@ -31,7 +33,7 @@ struct miopen_gemm
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -2,9 +2,11 @@
#define MIGRAPH_GUARD_MIGRAPHLIB_HIP_HPP
#include <migraph/operators.hpp>
#include <migraph/config.hpp>
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
migraph::argument allocate_gpu(const migraph::shape& s, bool host = false);
......@@ -85,7 +87,9 @@ struct hip_copy
}
int output_alias(const std::vector<shape>&) const { return 1; }
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
......@@ -7,6 +7,7 @@
#include <migraph/operators.hpp>
#include <migraph/generate.hpp>
#include <migraph/shape_for_each.hpp>
#include <migraph/config.hpp>
#include <migraph/gpu/miopen.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/dfor.hpp>
......@@ -18,6 +19,7 @@
#include <utility>
namespace migraph {
inline namespace MIGRAPH_INLINE_NS {
namespace gpu {
struct miopen_leaky_relu
......@@ -31,7 +33,7 @@ struct miopen_leaky_relu
};
} // namespace gpu
} // namespace MIGRAPH_INLINE_NS
} // namespace migraph
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment