"vscode:/vscode.git/clone" did not exist on "1873dc007bc8147aa255f4d01b24d0ab11121c44"
Commit eb22c0e5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'rocblas_api_opt' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into layernorm_half2

parents 5f4e8561 b7a1823c
...@@ -152,6 +152,35 @@ struct array_base ...@@ -152,6 +152,35 @@ struct array_base
} }
}; };
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wnon-template-friend"
#endif
template <class T>
struct holder
{
// Friend injection
friend auto migraphx_adl_handle_lookup(holder<T>);
// Function left unimplemented since its only used in non-evaluated
// context
T get() const;
};
template <class C, class T>
struct handle_lookup
{
friend auto migraphx_adl_handle_lookup(holder<T>) { return holder<C>{}; }
};
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif
template <class T>
using as_handle = decltype(
migraphx_adl_handle_lookup(holder<std::remove_cv_t<std::remove_pointer_t<T>>>{}).get());
struct own struct own
{ {
}; };
...@@ -159,8 +188,8 @@ struct borrow ...@@ -159,8 +188,8 @@ struct borrow
{ {
}; };
template <class T, class D, D Deleter, class A, A Assigner> template <class Derived, class T, class D, D Deleter, class A, A Assigner>
struct handle_base struct handle_base : handle_lookup<Derived, std::remove_cv_t<T>>
{ {
handle_base() : m_handle(nullptr) {} handle_base() : m_handle(nullptr) {}
template <class F, class... Ts> template <class F, class... Ts>
...@@ -204,7 +233,8 @@ struct handle_base ...@@ -204,7 +233,8 @@ struct handle_base
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<> #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) handle_base<>
#else #else
#define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \ #define MIGRAPHX_DETAIL_HANDLE_BASE(name, const_) \
handle_base<const_ migraphx_##name, \ handle_base<name, \
const_ migraphx_##name, \
decltype(&migraphx_##name##_destroy), \ decltype(&migraphx_##name##_destroy), \
migraphx_##name##_destroy, \ migraphx_##name##_destroy, \
decltype(&migraphx_##name##_assign_to), \ decltype(&migraphx_##name##_assign_to), \
......
#ifndef MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
#include <migraphx/config.hpp>
#include <migraphx/optional.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/type_name.hpp>
#include <cassert>
#include <string_view>
#include <typeindex>
#include <type_traits>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct any_ptr
{
any_ptr() = default;
template <class T>
any_ptr(T* p) : ptr(p), ti(typeid(T*)), name(get_name<T*>())
{
}
any_ptr(void* p, std::string_view pname) : ptr(p), name(pname) {}
void* get(std::string_view n) const
{
if(name != n)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} +
" != " + std::string{n});
return ptr;
}
template <class T>
T get() const
{
static_assert(std::is_pointer<T>{}, "Must be a pointer");
assert(ptr != nullptr);
if(ti and std::type_index{typeid(T)} != *ti)
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
else if(name != get_name<T>())
MIGRAPHX_THROW("any_ptr: type mismatch: " + std::string{name} + " != " + get_name<T>());
return reinterpret_cast<T>(ptr);
}
void* unsafe_get() const { return ptr; }
private:
void* ptr = nullptr;
optional<std::type_index> ti = nullopt;
std::string_view name = "";
template <class T>
static const std::string& get_name()
{
return get_type_name<std::remove_cv_t<std::remove_pointer_t<T>>>();
}
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_ANY_PTR_HPP
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -37,6 +38,12 @@ void from_value_context(T&, const value&) ...@@ -37,6 +38,12 @@ void from_value_context(T&, const value&)
{ {
} }
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
/* /*
* Type-erased interface for: * Type-erased interface for:
* *
...@@ -44,6 +51,7 @@ void from_value_context(T&, const value&) ...@@ -44,6 +51,7 @@ void from_value_context(T&, const value&)
* { * {
* value to_value() const; * value to_value() const;
* void from_value(const value& v) ; * void from_value(const value& v) ;
* any_ptr get_queue() ;
* void finish() const; * void finish() const;
* }; * };
* *
...@@ -124,6 +132,12 @@ struct context ...@@ -124,6 +132,12 @@ struct context
(*this).private_detail_te_get_handle().from_value(v); (*this).private_detail_te_get_handle().from_value(v);
} }
any_ptr get_queue()
{
assert((*this).private_detail_te_handle_mem_var);
return (*this).private_detail_te_get_handle().get_queue();
}
void finish() const void finish() const
{ {
assert((*this).private_detail_te_handle_mem_var); assert((*this).private_detail_te_handle_mem_var);
...@@ -145,6 +159,7 @@ struct context ...@@ -145,6 +159,7 @@ struct context
virtual value to_value() const = 0; virtual value to_value() const = 0;
virtual void from_value(const value& v) = 0; virtual void from_value(const value& v) = 0;
virtual any_ptr get_queue() = 0;
virtual void finish() const = 0; virtual void finish() const = 0;
}; };
...@@ -176,6 +191,19 @@ struct context ...@@ -176,6 +191,19 @@ struct context
from_value_context(private_detail_te_self, v); from_value_context(private_detail_te_self, v);
} }
template <class T>
static auto private_detail_te_default_get_queue(char, T&& private_detail_te_self)
-> decltype(private_detail_te_self.get_queue())
{
return private_detail_te_self.get_queue();
}
template <class T>
static any_ptr private_detail_te_default_get_queue(float, T&& private_detail_te_self)
{
return get_queue_context(private_detail_te_self);
}
template <typename PrivateDetailTypeErasedT> template <typename PrivateDetailTypeErasedT>
struct private_detail_te_handle_type : private_detail_te_handle_base_type struct private_detail_te_handle_type : private_detail_te_handle_base_type
{ {
...@@ -216,6 +244,12 @@ struct context ...@@ -216,6 +244,12 @@ struct context
private_detail_te_default_from_value(char(0), private_detail_te_value, v); private_detail_te_default_from_value(char(0), private_detail_te_value, v);
} }
any_ptr get_queue() override
{
return private_detail_te_default_get_queue(char(0), private_detail_te_value);
}
void finish() const override { private_detail_te_value.finish(); } void finish() const override { private_detail_te_value.finish(); }
PrivateDetailTypeErasedT private_detail_te_value; PrivateDetailTypeErasedT private_detail_te_value;
......
...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx, ...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
T alpha, T alpha,
T beta, T beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx, ...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
auto compute_type = output_type; auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag = rocblas_gemm_flags flag =
...@@ -77,8 +83,20 @@ void gemm_impl(context& ctx, ...@@ -77,8 +83,20 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta); auto alpha_r = as(alpha);
auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v{&alpha_r};
void* beta_v{&beta_r};
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -93,10 +111,6 @@ void gemm_impl(context& ctx, ...@@ -93,10 +111,6 @@ void gemm_impl(context& ctx,
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1)
{ {
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke(&rocblas_gemm_ex, rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -104,14 +118,14 @@ void gemm_impl(context& ctx, ...@@ -104,14 +118,14 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -132,7 +146,7 @@ void gemm_impl(context& ctx, ...@@ -132,7 +146,7 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
...@@ -141,7 +155,7 @@ void gemm_impl(context& ctx, ...@@ -141,7 +155,7 @@ void gemm_impl(context& ctx,
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -164,9 +178,10 @@ void gemm(context& ctx, ...@@ -164,9 +178,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
void gemm(context& ctx, void gemm(context& ctx,
...@@ -174,9 +189,10 @@ void gemm(context& ctx, ...@@ -174,9 +189,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
} // namespace gpu } // namespace gpu
......
...@@ -235,6 +235,8 @@ struct context ...@@ -235,6 +235,8 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
any_ptr get_queue() { return get_stream().get(); }
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
......
...@@ -25,6 +25,7 @@ struct rocblas_gemm ...@@ -25,6 +25,7 @@ struct rocblas_gemm
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -80,11 +81,17 @@ struct rocblas_gemm ...@@ -80,11 +81,17 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
else else
{ {
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format); gemm(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32);
} }
return args.back(); return args.back();
} }
......
...@@ -14,13 +14,15 @@ void gemm(context& ctx, ...@@ -14,13 +14,15 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp> #include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp> #include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
...@@ -60,6 +61,7 @@ struct miopen_apply ...@@ -60,6 +61,7 @@ struct miopen_apply
std::unordered_map<instruction_ref, std::string> prog_output_names{}; std::unordered_map<instruction_ref, std::string> prog_output_names{};
bool offload_copy = false; bool offload_copy = false;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
context& get_context() const context& get_context() const
{ {
...@@ -96,13 +98,22 @@ struct miopen_apply ...@@ -96,13 +98,22 @@ struct miopen_apply
} }
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto& ctx = get_context(); auto& ctx = get_context();
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag; rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag); rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4); int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
...@@ -337,7 +348,7 @@ struct miopen_apply ...@@ -337,7 +348,7 @@ struct miopen_apply
} }
} }
return mod->replace_instruction( return mod->replace_instruction(
ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format}, refs); ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
}); });
} }
......
#include <migraphx/any_ptr.hpp>
#include <test.hpp>
TEST_CASE(test_int_id)
{
int i = 1;
migraphx::any_ptr p = &i;
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
}
TEST_CASE(test_int_name)
{
int i = 1;
void* vp = &i;
migraphx::any_ptr p{vp, migraphx::get_type_name(i)};
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(float{})); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -12,6 +12,7 @@ endfunction() ...@@ -12,6 +12,7 @@ endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp"
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include "test.hpp" #include "test.hpp"
TEST_CASE(gpu_context) TEST_CASE(gpu_context_serialize)
{ {
migraphx::context ctx = migraphx::gpu::context{0, 3}; migraphx::context ctx = migraphx::gpu::context{0, 3};
...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context) ...@@ -25,4 +25,10 @@ TEST_CASE(gpu_context)
EXPECT(v == v1); EXPECT(v == v1);
} }
TEST_CASE(context_queue)
{
migraphx::context ctx = migraphx::gpu::context{0, 3};
EXPECT(ctx.get_queue().get<hipStream_t>() != nullptr);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <utility> #include <utility>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/value.hpp> #include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -33,12 +34,21 @@ value to_value_context(const T&) ...@@ -33,12 +34,21 @@ value to_value_context(const T&)
} }
template <class T> template <class T>
void from_value_context(T&, const value&){} void from_value_context(T&, const value&)
{
}
template <class T>
any_ptr get_queue_context(T&)
{
return {};
}
<% <%
interface('context', interface('context',
virtual('to_value', returns = 'value', const = True, default = 'to_value_context'), virtual('to_value', returns = 'value', const = True, default = 'to_value_context'),
virtual('from_value', v = 'const value&', default = 'from_value_context'), virtual('from_value', v = 'const value&', default = 'from_value_context'),
virtual('get_queue', returns = 'any_ptr', default = 'get_queue_context'),
virtual('finish', returns = 'void', const = True)) %> virtual('finish', returns = 'void', const = True)) %>
inline void migraphx_to_value(value& v, const context& ctx) inline void migraphx_to_value(value& v, const context& ctx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment