Commit e2eb6036 authored by Paul's avatar Paul
Browse files

Merge

parents 298c93d5 1e0bbd78
...@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal) ...@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
{ {
assert(s.standard); assert(s.standard);
assert(s.elements() > 0); assert(s.elements() > 0);
index_int n = s.elements(); index_int n = s.elements();
index_int groups = (n + nlocal - 1) / nlocal; index_int groups = (n + nlocal - 1) / nlocal;
index_int nglobal = std::min<index_int>(128, groups) * nlocal; // max possible number of blocks is set to 1B (1,073,741,824)
index_int nglobal = std::min<index_int>(1073741824, groups) * nlocal;
assert(groups > 0); assert(groups > 0);
assert(nglobal > 0); assert(nglobal > 0);
......
...@@ -44,12 +44,19 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, ...@@ -44,12 +44,19 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
template <index_int N, class Op, class T, class Input, class Output> template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output) __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{ {
block_scan<N>(idx, block_scan<N>(
op, idx,
init, op,
[&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); }, init,
input, [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
output); input,
output);
}
template <class F>
constexpr auto reverse_scan(index_int n, F f)
{
return [=](auto i, auto&&... xs) { return f(n - i - 1, xs...); };
} }
} // namespace device } // namespace device
......
...@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f) ...@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
{ {
switch(n) switch(n)
{ {
case 1: case 1: {
{
f(std::integral_constant<index_int, 1>{}); f(std::integral_constant<index_int, 1>{});
break; break;
} }
case 2: case 2: {
{
f(std::integral_constant<index_int, 2>{}); f(std::integral_constant<index_int, 2>{});
break; break;
} }
case 3: case 3: {
{
f(std::integral_constant<index_int, 3>{}); f(std::integral_constant<index_int, 3>{});
break; break;
} }
case 4: case 4: {
{
f(std::integral_constant<index_int, 4>{}); f(std::integral_constant<index_int, 4>{});
break; break;
} }
case 5: case 5: {
{
f(std::integral_constant<index_int, 5>{}); f(std::integral_constant<index_int, 5>{});
break; break;
} }
......
...@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg ...@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
// fill all output to 0 first // fill all output to 0 first
idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; }); idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
block_scan<block_size>(idx, block_scan<block_size>(
sum{}, idx,
0, sum{},
elem_num, 0,
[&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; }, elem_num,
[&](auto j, auto x) { [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
auto out_loc = x - 1; [&](auto j, auto x) {
if(float_equal(in_ptr[j], 0)) auto out_loc = x - 1;
return; if(float_equal(in_ptr[j], 0))
return;
auto index = si.multi(j); auto index = si.multi(j);
for(size_t k = 0; k < index.size(); ++k) for(size_t k = 0; k < index.size(); ++k)
{ {
ptr[k * elem_num + out_loc] = index[k]; ptr[k * elem_num + out_loc] = index[k];
} }
}); });
}); });
}); });
......
#include <migraphx/gpu/device/prefix_scan_sum.hpp> #include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp> #include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp> #include <migraphx/gpu/device/reduce_ops.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
namespace migraphx { namespace migraphx {
...@@ -8,29 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -8,29 +9,108 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis) void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse)
{ {
const index_int block_size = 256; const index_int max_block_size = 256;
const index_int n = arg.get_shape().lens()[axis]; const index_int n = arg.get_shape().lens()[axis];
auto rlens = result.get_shape().lens(); auto rlens = result.get_shape().lens();
rlens[axis] = 1; rlens[axis] = 1;
hip_visit_all(result, arg, result.get_shape().with_lens(rlens))( hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
[=](auto output, auto input, auto rshape) { [=](auto output, auto input, auto rshape) {
gs_launch(stream, rshape.elements() * block_size, block_size)( const index_int block_size = compute_block_size(rshape.elements(), max_block_size);
[=](auto i, auto idx) __device__ { if(reverse and exclusive)
const auto ridx = rshape.multi(i / block_size); {
auto compute_idx = [&](auto j) { gs_launch(stream, rshape.elements() * block_size, block_size)(
auto k = ridx; [=](auto i, auto idx) __device__ {
k[axis] = j; const auto ridx = rshape.multi(i / block_size);
return k; auto compute_idx = [&](auto j) {
}; auto k = ridx;
block_scan<block_size>(idx, k[axis] = j;
sum{}, return k;
0, };
n, block_scan<max_block_size>(
[&](auto j) { return input[compute_idx(j)]; }, idx,
[&](auto j, auto x) { output[compute_idx(j)] = x; }); sum{},
}); 0,
n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) {
if(j == n - 1)
output[compute_idx(j)] = 0;
if(j > 0)
output[compute_idx(j - 1)] = x;
}));
});
}
else if(reverse)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
reverse_scan(n, [&](auto j) { return input[compute_idx(j)]; }),
reverse_scan(n, [&](auto j, auto x) { output[compute_idx(j)] = x; }));
});
}
else if(exclusive)
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) {
auto k = j + 1;
if(j == 0)
output[compute_idx(0)] = 0;
if(k < n)
output[compute_idx(k)] = x;
});
});
}
else
{
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<max_block_size>(
idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) { output[compute_idx(j)] = x; });
});
}
}); });
} }
......
file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)
add_executable(gpu-driver add_executable(gpu-driver
action.cpp ${GPU_DRIVER_SRCS}
compile_pointwise.cpp
main.cpp
parser.cpp
perf.cpp
run_op.cpp
) )
target_include_directories(gpu-driver PRIVATE include) target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu) target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
#include <migraphx/gpu/driver/action.hpp> #include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/perf.hpp> #include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/compile_pointwise.hpp> #include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
...@@ -8,13 +8,13 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -8,13 +8,13 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace driver { namespace driver {
struct compile_pointwise : action<compile_pointwise> struct compile_op : action<compile_op>
{ {
static void apply(const parser& p, const value& v) static void apply(const parser& p, const value& v)
{ {
context ctx; context ctx;
auto inputs = p.parse_shapes(v.at("inputs")); auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_pointwise(ctx, inputs, v.at("lambda").to<std::string>()); auto op = gpu::compile_op(v.at("name").to<std::string>(), ctx, inputs, v);
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100)); double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl; std::cout << op << ": " << t << "ms" << std::endl;
} }
......
...@@ -587,6 +587,11 @@ struct miopen_fusion ...@@ -587,6 +587,11 @@ struct miopen_fusion
return pack(f(self.ops, "ops")); return pack(f(self.ops, "ops"));
} }
// Returns the index of the last entry of `shapes` (size() - 1), converted to
// std::ptrdiff_t on return. The subtraction is done in the unsigned size type,
// so for an empty vector it wraps before the conversion.
// NOTE(review): presumably this names the argument whose buffer the output
// aliases -- confirm against the operation interface.
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
    const auto last_index = shapes.size() - 1;
    return last_index;
}
value compile(context& ctx, const shape&, std::vector<shape> inputs) value compile(context& ctx, const shape&, std::vector<shape> inputs)
{ {
// Compensate for allocation // Compensate for allocation
......
...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx, ...@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
T alpha, T alpha,
T beta, T beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
bool transa = args[0].get_shape().transposed(); bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed(); bool transb = args[1].get_shape().transposed();
...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx, ...@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
output_type = rocblas_datatype_i32_r; output_type = rocblas_datatype_i32_r;
} }
auto compute_type = output_type; auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 #if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag = rocblas_gemm_flags flag =
...@@ -77,8 +83,19 @@ void gemm_impl(context& ctx, ...@@ -77,8 +83,19 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -104,14 +121,14 @@ void gemm_impl(context& ctx, ...@@ -104,14 +121,14 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
arg_type, arg_type,
lda, lda,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -132,7 +149,7 @@ void gemm_impl(context& ctx, ...@@ -132,7 +149,7 @@ void gemm_impl(context& ctx,
n, n,
m, m,
k, k,
&alpha_r, alpha_v,
to_pointer(args.at(1)), to_pointer(args.at(1)),
arg_type, arg_type,
ldb, ldb,
...@@ -141,7 +158,7 @@ void gemm_impl(context& ctx, ...@@ -141,7 +158,7 @@ void gemm_impl(context& ctx,
arg_type, arg_type,
lda, lda,
m * k, m * k,
&beta_r, beta_v,
to_pointer(args[2]), to_pointer(args[2]),
output_type, output_type,
ldc, ldc,
...@@ -164,9 +181,10 @@ void gemm(context& ctx, ...@@ -164,9 +181,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
void gemm(context& ctx, void gemm(context& ctx,
...@@ -174,9 +192,10 @@ void gemm(context& ctx, ...@@ -174,9 +192,10 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format) bool int8_x4_format,
bool compute_fp32)
{ {
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
} // namespace gpu } // namespace gpu
......
...@@ -109,10 +109,9 @@ argument register_on_gpu(const argument& arg) ...@@ -109,10 +109,9 @@ argument register_on_gpu(const argument& arg)
{ {
auto arg_shared = arg.share(); auto arg_shared = arg.share();
auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes())); auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes()));
return {arg_shared.get_shape(), return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable {
[ p, a = std::move(arg_shared) ]() mutable {return get_device_ptr(p.get()); return get_device_ptr(p.get());
} }}; // namespace gpu
}; // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
argument to_gpu(const argument& arg, bool host) argument to_gpu(const argument& arg, bool host)
......
...@@ -17,8 +17,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -17,8 +17,6 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std::string enum_params(std::size_t count, std::string param); std::string enum_params(std::size_t count, std::string param);
std::size_t compute_global(std::size_t n, std::size_t local = 1024);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -8,6 +8,8 @@ namespace migraphx { ...@@ -8,6 +8,8 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct context;
struct hip_compile_options struct hip_compile_options
{ {
std::size_t global; std::size_t global;
...@@ -17,8 +19,24 @@ struct hip_compile_options ...@@ -17,8 +19,24 @@ struct hip_compile_options
std::string kernel_name = "kernel"; std::string kernel_name = "kernel";
std::string params = ""; std::string params = "";
std::vector<shape> virtual_inputs = {}; std::vector<shape> virtual_inputs = {};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The default local to use if it is missing from the v parameter
*/
void set_launch_params(const value& v,
const std::function<std::size_t(std::size_t local)>& compute_global,
std::size_t default_local = 1024);
}; };
/// Compute global for n elements, but max out on target-specific upper limit
std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
operation compile_hip_code_object(const std::string& content, hip_compile_options options); operation compile_hip_code_object(const std::string& content, hip_compile_options options);
} // namespace gpu } // namespace gpu
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct context;
// Overload taking the pointwise code as strings; `preamble` defaults to empty.
// Declaration only -- the definition is not visible here.
// NOTE(review): `lambda` and `preamble` look like source text that gets compiled
// into the returned operation -- confirm against the implementation.
operation compile_pointwise(context& ctx,
const std::vector<shape>& inputs,
const std::string& lambda,
const std::string& preamble = "");
// Overload taking a module (passed by value) instead of source strings.
operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, module m);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_POINTWISE_HPP
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// Declaration only -- the definition is not visible here. Returns an operation
// built from a context, a list of shapes and a value.
// NOTE(review): `io_shapes` presumably holds both input and output shapes, and
// `val` the operator attributes -- confirm against the implementation.
operation compile_roialign(context& ctx, const std::vector<shape>& io_shapes, const value& val);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#ifndef MIGRAPHX_GUARD_GPU_COMPILER_HPP
#define MIGRAPHX_GUARD_GPU_COMPILER_HPP
#include <migraphx/config.hpp>
#include <migraphx/auto_register.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/value.hpp>
#include <migraphx/module.hpp>
#include <migraphx/instruction.hpp>
#include <functional>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// Callable invoked with a module and one of its instructions; returns nothing.
using compiler_replace = std::function<void(module& m, instruction_ref ins)>;
// Callable that takes a context, an instruction and an operation and yields a
// compiler_replace.
using compiler_compile = std::function<compiler_replace(context&, instruction_ref, operation)>;
// Callable that takes a context, input shapes and a value and yields an operation.
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
// Declaration only. NOTE(review): presumably stores both callbacks keyed by
// `name` -- confirm in the definition.
void register_compiler(const std::string& name, compiler_compile c, compiler_compile_op cop);
// Declaration only. NOTE(review): presumably reports whether a compiler was
// registered under `name` -- confirm in the definition.
bool has_compiler_for(const std::string& name);
// Declaration only. Returns a compiler_replace for the given instruction and operation.
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op);
// Declaration only. Returns an operation for the compiler identified by `name`.
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
/// Default-constructs a T and, for every name returned by its names() member,
/// registers two forwarding callbacks: one calling T::compile and one calling
/// T::compile_op. Each callback holds its own copy of the T instance.
template <class T>
void register_compiler()
{
    T instance;
    for(auto&& op_name : instance.names())
    {
        auto on_compile = [instance](auto&&... args) {
            return instance.compile(std::forward<decltype(args)>(args)...);
        };
        auto on_compile_op = [instance](auto&&... args) {
            return instance.compile_op(std::forward<decltype(args)>(args)...);
        };
        // Hand the closures over as rvalues, matching a call with inline lambdas.
        register_compiler(op_name, std::move(on_compile), std::move(on_compile_op));
    }
}
/// Action type whose static apply<Compiler>() forwards to
/// register_compiler<Compiler>().
struct register_compiler_action
{
    template <class Compiler>
    static void apply()
    {
        register_compiler<Compiler>();
    }
};
// auto_register instantiated with register_compiler_action as the action type.
// NOTE(review): presumably auto_register triggers
// register_compiler_action::apply<T>() during static initialisation -- confirm
// in migraphx/auto_register.hpp.
template <class T>
using auto_register_compiler = auto_register<register_compiler_action, T>;
/// Base template for compiler implementations. It inherits
/// auto_register_compiler<Derived> and supplies a replace() helper plus a
/// default compile_op().
template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
    /// Returns a callable holding a copy of `op`. When invoked with a module and
    /// an instruction, it calls replace_instruction on the module with that
    /// instruction, the stored operation and the instruction's own inputs.
    auto replace(const operation& op) const
    {
        return [op](module& mod, instruction_ref target) {
            mod.replace_instruction(target, op, target->inputs());
        };
    }

    /// Default implementation: ignores all arguments and returns an operation
    /// initialised from an empty brace list.
    operation compile_op(context&, const std::vector<shape>&, const value&) const
    {
        return {};
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILER_HPP
...@@ -154,6 +154,13 @@ struct hip_device ...@@ -154,6 +154,13 @@ struct hip_device
std::size_t get_cu_count() const { return device_props.multiProcessorCount; } std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
// Returns device_props.maxThreadsPerMultiProcessor converted to std::size_t.
std::size_t get_max_workitems_per_cu() const
{
    const auto per_cu = static_cast<std::size_t>(device_props.maxThreadsPerMultiProcessor);
    return per_cu;
}
// Returns device_props.maxThreadsPerBlock converted to std::size_t.
std::size_t get_max_workitems_per_block() const
{
    const auto per_block = static_cast<std::size_t>(device_props.maxThreadsPerBlock);
    return per_block;
}
private: private:
std::size_t device_id = 0; std::size_t device_id = 0;
std::size_t current_stream = 0; std::size_t current_stream = 0;
...@@ -235,6 +242,8 @@ struct context ...@@ -235,6 +242,8 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
// Returns the result of get_stream().get(), converted to any_ptr.
// NOTE(review): presumably this exposes the underlying stream handle in
// type-erased form -- confirm the type that get() yields.
any_ptr get_queue() { return get_stream().get(); }
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
......
...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis); void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse);
} // namespace device } // namespace device
} // namespace gpu } // namespace gpu
......
...@@ -25,6 +25,7 @@ struct rocblas_gemm ...@@ -25,6 +25,7 @@ struct rocblas_gemm
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
...@@ -80,11 +81,17 @@ struct rocblas_gemm ...@@ -80,11 +81,17 @@ struct rocblas_gemm
{ {
if(this->name() == "gpu::gemm") if(this->name() == "gpu::gemm")
{ {
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format); gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
} }
else else
{ {
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format); gemm(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32);
} }
return args.back(); return args.back();
} }
......
...@@ -14,13 +14,15 @@ void gemm(context& ctx, ...@@ -14,13 +14,15 @@ void gemm(context& ctx,
const std::vector<argument>& args, const std::vector<argument>& args,
float alpha, float alpha,
float beta, float beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& args, const std::vector<argument>& args,
int32_t alpha, int32_t alpha,
int32_t beta, int32_t beta,
bool int8_x4_format); bool int8_x4_format,
bool compute_fp32);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <miopen/miopen.h> #include <miopen/miopen.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <sstream>
#ifdef HAS_FIND_MODE_API #ifdef HAS_FIND_MODE_API
extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc, extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc,
int findMode); int findMode);
...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op) ...@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op)
inline pooling_descriptor make_pooling(const migraphx::op::pooling& op) inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
{ {
miopenPoolingMode_t mode; miopenPoolingMode_t mode;
if(op.mode == "max") if(op.mode == op::pooling_mode::max)
mode = miopenPoolingMax; mode = miopenPoolingMax;
else if(op.mode == "average") else if(op.mode == op::pooling_mode::average)
mode = miopenPoolingAverage; mode = miopenPoolingAverage;
else else
MIGRAPHX_THROW("Unknown mode for pooling: " + op.mode); {
std::stringstream ss("Unknown mode for pooling: ");
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor); auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims(); int kdims = op.kdims();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment