Unverified commit 734deb37, authored by Paul Fultz II and committed by GitHub

Merge branch 'master' into transpose

parents 272bf6b1 63978fb4
@@ -534,7 +534,7 @@ struct onnx_parser
     case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
     case onnx::TensorProto::STRING: throw std::runtime_error("");
     case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
-    case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
+    case onnx::TensorProto::FLOAT16: return literal{{shape::half_type, dims}, s.data()};
     case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
     case onnx::TensorProto::UINT32: throw std::runtime_error("");
     case onnx::TensorProto::UINT64: throw std::runtime_error("");
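The first two hunks route FLOAT16 initializers through `shape::half_type` instead of throwing. For the raw-data path above, the buffer already holds packed IEEE fp16 words, so the literal can be built straight from the bytes. A minimal standalone sketch of that bit-for-bit reinterpretation, assuming `half_float::half` as the 2-byte fp16 type (that migraph's `half` is this type is an assumption, not stated in the diff):

```cpp
#include <cstring>
#include <vector>
#include <half.hpp> // half_float::half, assumed fp16 type (2 bytes, IEEE binary16)

// Sketch: a FLOAT16 raw_data buffer is n packed 16-bit floats; copying the
// bytes into half storage reproduces the values with no numeric conversion.
std::vector<half_float::half> unpack_fp16_raw(const char* raw, std::size_t n)
{
    std::vector<half_float::half> out(n);
    std::memcpy(out.data(), raw, n * sizeof(half_float::half)); // 2 bytes per value
    return out;
}
```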
@@ -562,7 +562,8 @@ struct onnx_parser
     case onnx::TensorProto::STRING: throw std::runtime_error("");
     case onnx::TensorProto::BOOL:
         return literal{{shape::int32_type, dims}, t.int32_data().begin(), t.int32_data().end()};
-    case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
+    case onnx::TensorProto::FLOAT16:
+        return literal{{shape::half_type, dims}, t.float_data().begin(), t.float_data().end()};
     case onnx::TensorProto::DOUBLE:
         return literal{
             {shape::double_type, dims}, t.double_data().begin(), t.double_data().end()};
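One caveat on the hunk above: the ONNX protobuf comments specify that FLOAT16 values are bit-packed into the low 16 bits of `int32_data`, one value per entry, so `t.float_data()` is only populated when the producer chose that field instead. A hedged sketch of the spec-conformant unpacking, again assuming `half_float::half` as the fp16 type:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>
#include <half.hpp> // half_float::half as the fp16 type (assumption)

// Sketch: reinterpret the low 16 bits of each int32_data entry as an IEEE
// half; this is a bit-wise move, not a numeric int-to-float cast.
std::vector<half_float::half> unpack_fp16_int32(const std::vector<std::int32_t>& packed)
{
    std::vector<half_float::half> out;
    out.reserve(packed.size());
    for(std::int32_t v : packed)
    {
        auto bits = static_cast<std::uint16_t>(v & 0xFFFF); // value lives in the low half
        half_float::half h;
        std::memcpy(&h, &bits, sizeof(h));
        out.push_back(h);
    }
    return out;
}
```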
@@ -593,8 +594,7 @@ struct onnx_parser
         break; // throw std::runtime_error("Unsupported type STRING");
     case onnx::TensorProto::BOOL:
         break; // throw std::runtime_error("Unsupported type BOOL");
-    case onnx::TensorProto::FLOAT16:
-        break; // throw std::runtime_error("Unsupported type FLOAT16");
+    case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
     case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
     case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
     case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
@@ -132,6 +132,8 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
 {
     if(ins->name() != "gpu::convolution")
         return false;
+    if(ins->get_shape().type() != shape::float_type)
+        return false;
     auto wei = ins->inputs().at(1)->get_shape();
     assert(wei.lens().size() == 4);
     auto conv = any_cast<miopen_convolution>(ins->get_operator());
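This guard keeps non-fp32 convolutions out of the fusion matcher, presumably because the fused kernels built downstream only handle float tensors; half-precision convolutions now fall through to the plain, unfused path (the fp32-only rationale is an inference, not stated in the diff). A minimal sketch of the predicate in isolation:

```cpp
#include <migraph/shape.hpp> // assumed header path for migraph::shape

// Sketch: a convolution is a fusion candidate only when its output is fp32;
// half/double results take the unfused gpu::convolution path instead.
bool fusable_output_type(const migraph::shape& out)
{
    return out.type() == migraph::shape::float_type;
}
```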
@@ -8,6 +8,65 @@ namespace migraph {
 inline namespace MIGRAPH_INLINE_NS {
 namespace gpu {
+
+template <class... Ts>
+void generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
+{
+    rocblas_sgemm(std::forward<Ts>(xs)...);
+}
+
+template <class... Ts>
+void generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
+{
+    rocblas_dgemm(std::forward<Ts>(xs)...);
+}
+
+template <class... Ts>
+void generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
+{
+    rocblas_hgemm(std::forward<Ts>(xs)...);
+}
+
+template <class T, class... Ts>
+void generic_rocblas_gemm(shape::as<T>, Ts&&...)
+{
+    MIGRAPH_THROW("Type unsupported by rocblas");
+}
+
+template <class T>
+struct compute_rocblas_type
+{
+    using type = T;
+};
+
+template <class T>
+struct compute_rocblas_type<const T>
+{
+    using type = const typename compute_rocblas_type<T>::type;
+};
+
+template <>
+struct compute_rocblas_type<half>
+{
+    using type = rocblas_half;
+};
+
+template <class T>
+using rb_type = typename compute_rocblas_type<T>::type;
+
+template <class T>
+rb_type<T> to_rocblas_type(T x)
+{
+    return reinterpret_cast<const rb_type<T>&>(x);
+}
+
+template <class T>
+rb_type<T>* to_rocblas_type(T* x)
+{
+    return reinterpret_cast<rb_type<T>*>(x);
+}
+
+rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
+
 shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     check_shapes{inputs, *this}.has(3);
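The new helpers above use tag dispatch: `shape::as<T>` is an empty tag carrying the element type, so overload resolution picks `rocblas_sgemm`, `rocblas_dgemm`, or `rocblas_hgemm` at compile time, with the unconstrained template as a catch-all that throws for anything else. `compute_rocblas_type` exists because rocBLAS spells fp16 as `rocblas_half`, a distinct type from migraph's `half`, so scalars and pointers are reinterpreted rather than converted. A self-contained sketch of the dispatch pattern (hypothetical names, not migraph's):

```cpp
#include <iostream>
#include <stdexcept>

// Empty tag type carrying the element type, in the spirit of shape::as<T>.
template <class T>
struct as_type {};

// Non-template overloads beat the template when the tag matches exactly...
void gemm_dispatch(as_type<float>)  { std::cout << "sgemm\n"; }
void gemm_dispatch(as_type<double>) { std::cout << "dgemm\n"; }

// ...and the generic template catches every unsupported element type.
template <class T>
void gemm_dispatch(as_type<T>) { throw std::runtime_error("type unsupported by rocblas"); }

int main()
{
    gemm_dispatch(as_type<float>{});  // prints "sgemm"
    gemm_dispatch(as_type<double>{}); // prints "dgemm"
    // gemm_dispatch(as_type<int>{}); // would throw at run time
}
```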
@@ -27,20 +86,27 @@ argument miopen_gemm::compute(context& ctx,
     rocblas_int m = output_shape.lens()[0];
     rocblas_int n = output_shape.lens()[1];
     rocblas_int k = args[0].get_shape().lens()[1];
-    rocblas_sgemm(ctx.get_stream().get_rocblas(),
+    output_shape.visit_type([&](auto as) {
+        auto alpha_r    = to_rocblas_type(as(alpha));
+        auto beta_r     = to_rocblas_type(as(beta));
+        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        generic_rocblas_gemm(as,
+                             ctx.get_stream().get_rocblas(),
                              transb ? rocblas_operation_transpose : rocblas_operation_none,
                              transa ? rocblas_operation_transpose : rocblas_operation_none,
                              n,
                              m,
                              k,
-                             &alpha,
-                             args[1].implicit(),
+                             &alpha_r,
+                             to_pointer(args[1]),
                              ldb,
-                             args[0].implicit(),
+                             to_pointer(args[0]),
                              lda,
-                             &beta,
-                             args[2].implicit(),
+                             &beta_r,
+                             to_pointer(args[2]),
                              ldc);
+    });
     return args[2];
 }
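Two things happen in this hunk. First, `output_shape.visit_type` maps the runtime element type onto the `as` tag, so the same call site now serves float, double, and half. Second, note the operand order: rocBLAS is column-major, and a row-major M x N buffer read column-major is exactly the transpose, so the call computes C^T = B^T * A^T by passing `args[1]` (B) first and exchanging m and n; the bytes written out are the row-major C = A * B. A small self-contained check of that identity against a naive column-major GEMM (all names here are illustrative):

```cpp
#include <cassert>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n), leading dim = rows.
void cm_gemm(int m, int n, int k,
             const std::vector<int>& a, const std::vector<int>& b, std::vector<int>& c)
{
    for(int j = 0; j < n; j++)
        for(int i = 0; i < m; i++)
        {
            int acc = 0;
            for(int p = 0; p < k; p++)
                acc += a[p * m + i] * b[j * k + p];
            c[j * m + i] = acc;
        }
}

int main()
{
    // Row-major A (2x3) and B (3x2); the expected row-major C = A*B is 2x2.
    std::vector<int> a{1, 2, 3, 4, 5, 6};
    std::vector<int> b{7, 8, 9, 10, 11, 12};
    std::vector<int> c(4);
    // The swap: pass B's buffer first with m = cols(C), n = rows(C).
    // Read column-major, B's buffer is B^T and A's buffer is A^T, so this
    // computes C^T = B^T * A^T; stored column-major, that is row-major C.
    cm_gemm(/*m=*/2, /*n=*/2, /*k=*/3, b, a, c);
    assert((c == std::vector<int>{58, 64, 139, 154})); // matches row-major A*B
}
```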
@@ -508,6 +508,18 @@ struct test_gemm
     }
 };
 
+struct test_gemm_half
+{
+    migraph::program create_program() const
+    {
+        migraph::program p;
+        auto a = p.add_parameter("a", migraph::shape{migraph::shape::half_type, {4, 5}});
+        auto b = p.add_parameter("b", migraph::shape{migraph::shape::half_type, {5, 3}});
+        p.add_instruction(migraph::op::dot{}, a, b);
+        return p;
+    }
+};
+
 struct test_gemm_ld
 {
     migraph::program create_program() const
@@ -844,6 +856,7 @@ int main()
     verify_program<test_global_avg_pooling>();
     verify_program<test_global_max_pooling>();
     verify_program<test_gemm>();
+    verify_program<test_gemm_half>();
     // verify_program<test_gemm_ld>();
     verify_program<test_gemm_transposeb>();
     verify_program<test_gemm_transposea>();
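`test_gemm_half` drives a 4x5 by 5x3 half-precision `dot` through the full pipeline, and registering it with `verify_program` compares the GPU result against a reference evaluation. Worth keeping in mind if such a test flakes: fp16 carries a 10-bit mantissa, roughly 3 decimal digits, so element-wise comparisons need a looser tolerance than fp32. A hypothetical tolerance check of the kind such harnesses use (not migraph's actual comparator):

```cpp
#include <cmath>

// Hypothetical fp16-aware comparison: an fp32-sized epsilon would spuriously
// fail on correctly rounded half-precision results.
bool close_fp16(float got, float expected, float rtol = 1e-2f, float atol = 1e-3f)
{
    return std::fabs(got - expected) <= atol + rtol * std::fabs(expected);
}
```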