Commit d6b4ae77 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge optimization to print flops branch

parents bdf91961 abe2a889
...@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg ...@@ -25,22 +25,23 @@ argument nonzero(hipStream_t stream, const argument& result, const argument& arg
// fill all output to 0 first // fill all output to 0 first
idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; }); idx.local_stride(out_elem_num, [&](auto j) { ptr[j] = 0; });
block_scan<block_size>(idx, block_scan<block_size>(
sum{}, idx,
0, sum{},
elem_num, 0,
[&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; }, elem_num,
[&](auto j, auto x) { [&](auto j) { return (float_equal(in_ptr[j], 0)) ? 0 : 1; },
auto out_loc = x - 1; [&](auto j, auto x) {
if(float_equal(in_ptr[j], 0)) auto out_loc = x - 1;
return; if(float_equal(in_ptr[j], 0))
return;
auto index = si.multi(j); auto index = si.multi(j);
for(size_t k = 0; k < index.size(); ++k) for(size_t k = 0; k < index.size(); ++k)
{ {
ptr[k * elem_num + out_loc] = index[k]; ptr[k * elem_num + out_loc] = index[k];
} }
}); });
}); });
}); });
......
...@@ -24,12 +24,13 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& ...@@ -24,12 +24,13 @@ void prefix_scan_sum(hipStream_t stream, const argument& result, const argument&
k[axis] = j; k[axis] = j;
return k; return k;
}; };
block_scan<block_size>(idx, block_scan<block_size>(
sum{}, idx,
0, sum{},
n, 0,
[&](auto j) { return input[compute_idx(j)]; }, n,
[&](auto j, auto x) { output[compute_idx(j)] = x; }); [&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) { output[compute_idx(j)] = x; });
}); });
}); });
} }
......
...@@ -6,12 +6,139 @@ ...@@ -6,12 +6,139 @@
#include <migraphx/gpu/device/tensor.hpp> #include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp> #include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace device { namespace device {
// Reduction functor: lane-wise addition of two __half2 values.
// Used to accumulate the softmax denominator.
struct half2_sum
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 a, __half2 b) const
    {
        return __hadd2(a, b);
    }
};
// Lane-wise maximum of two __half2 values. Each lane is widened to
// float for the comparison, then the pair is packed back to half2.
inline __device__ __half2 hmax2(__half2 x, __half2 y)
{
    float2 a = __half22float2(x);
    float2 b = __half22float2(y);
    float lo = (a.x > b.x) ? a.x : b.x;
    float hi = (a.y > b.y) ? a.y : b.y;
    return __floats2half2_rn(lo, hi);
}
// Reduction functor: lane-wise maximum of two __half2 values (via hmax2).
// Used to find the row maximum for numerically stable softmax.
struct half2_max
{
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(__half2 a, __half2 b) const
    {
        return hmax2(a, b);
    }
};
// in_data is in shared memory
// Tree reduction over `buffer` (shared memory) holding `batch_item_num`
// __half2 elements, combining with binary op `op`.
// NOTE(review): the first stride is `block_size` (not block_size / 2), so
// one block can reduce up to 2 * block_size elements; this assumes
// batch_item_num <= 2 * block_size and that block_size is a power of two —
// confirm against the host-side launch configuration.
// The final op(low, high) folds the two lanes of buffer[0] together, so the
// returned __half2 carries the full reduction result in both lanes.
template <class Op>
__device__ __half2
block_reduce(__half2* buffer, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
// Make all producer writes to shared memory visible before reducing.
__syncthreads();
for(index_int s = block_size; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
buffer[tid] = op(buffer[tid], buffer[tid + s]);
}
// Barrier is outside the divergent branch: every thread reaches it.
__syncthreads();
}
// buffer[0] still holds two independent lanes; combine them.
auto lows2 = __low2half2(buffer[0]);
auto highs2 = __high2half2(buffer[0]);
return op(lows2, highs2);
}
// Half-precision softmax over the last axis, vectorized as __half2.
// One block handles one batch row of `batch_item_num` __half values.
// Dynamic shared memory must hold 2 * (batch_item_num / 2) __half2
// elements: one copy for destructive reduction, one to keep the inputs.
// NOTE(review): batch_item_num is halved with integer division, so an odd
// row length would drop the last element — confirm the host only selects
// this kernel for even sizes.
__global__ void
softmax_kernel(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
__half2* input = reinterpret_cast<__half2*>(data_in);
__half2* output = reinterpret_cast<__half2*>(data_out);
// Work in units of __half2 (two packed halves) from here on.
batch_item_num /= 2;
extern MIGRAPHX_DEVICE_SHARED __half2 buffer2[];
__half2* in_data_reduce = buffer2;
__half2* in_data = buffer2 + batch_item_num;
// Offset of this block's row in the flat input/output buffers.
int start = blockIdx.x * batch_item_num;
// Stage the row into both shared buffers (reduce scratch + stable copy).
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
auto d = input[i + start];
in_data[i] = d;
in_data_reduce[i] = d;
}
// Row maximum, subtracted before exp for numerical stability.
auto batch_max =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_max{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
in_data[i] = h2exp(__hsub2(in_data[i], batch_max));
in_data_reduce[i] = in_data[i];
}
// Row sum of the exponentials (softmax denominator).
auto batch_sum =
block_reduce(in_data_reduce, batch_item_num, threadIdx.x, block_size, half2_sum{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
output[i + start] = __h2div(in_data[i], batch_sum);
}
}
// in_data is in shared memory
// Tree reduction over `data` (shared memory) holding `batch_item_num`
// __half elements, combining with binary op `op`. Unlike block_reduce
// above, the first stride is block_size / 2, so this covers up to
// block_size elements; assumes block_size is a power of two — confirm
// against the host-side launch configuration.
template <class Op>
__device__ __half
block_reduce2(__half* data, index_int batch_item_num, index_int tid, index_int block_size, Op op)
{
// Make all producer writes to shared memory visible before reducing.
__syncthreads();
for(index_int s = block_size / 2; s > 0; s >>= 1)
{
if(tid < s and tid + s < batch_item_num)
{
data[tid] = op(data[tid], data[tid + s]);
}
// Barrier is outside the divergent branch: every thread reaches it.
__syncthreads();
}
return data[0];
}
// Half-precision softmax over the last axis, scalar (__half) variant.
// One block handles one batch row of `batch_item_num` values. Dynamic
// shared memory must hold 2 * batch_item_num __half elements: one copy
// for destructive reduction, one to keep the staged inputs.
__global__ void
softmax_kernel2(void* data_in, index_int batch_item_num, index_int block_size, void* data_out)
{
__half* input = reinterpret_cast<__half*>(data_in);
__half* output = reinterpret_cast<__half*>(data_out);
extern MIGRAPHX_DEVICE_SHARED __half buffer[];
__half* in_data_reduce = buffer;
__half* in_data = buffer + batch_item_num;
// Offset of this block's row in the flat input/output buffers.
int start = blockIdx.x * batch_item_num;
// Stage the row into both shared buffers (reduce scratch + stable copy).
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
auto d = input[i + start];
in_data[i] = d;
in_data_reduce[i] = d;
}
// Row maximum, subtracted before exp for numerical stability.
auto batch_max = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, max{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
// exp is computed in float and rounded back to half.
in_data[i] = __float2half(::exp(__half2float(in_data[i]) - __half2float(batch_max)));
in_data_reduce[i] = in_data[i];
}
// Row sum of the exponentials (softmax denominator).
auto batch_sum = block_reduce2(in_data_reduce, batch_item_num, threadIdx.x, block_size, sum{});
for(int i = threadIdx.x; i < batch_item_num; i += block_size)
{
output[i + start] = __float2half(__half2float(in_data[i]) / __half2float(batch_sum));
}
}
void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis) void softmax(hipStream_t stream, const argument& result, const argument& arg, int64_t axis)
{ {
auto batch_lens = result.get_shape().lens(); auto batch_lens = result.get_shape().lens();
...@@ -27,25 +154,38 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -27,25 +154,38 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
if(axis == batch_lens.size() - 1) if(axis == batch_lens.size() - 1)
{ {
gs_launch(stream, batch_shape.elements() * block_size, block_size)( auto in_type = result.get_shape().type();
[=](auto i, auto idx) __device__ { if(in_type == shape::half_type and batch_item_num <= 1024)
auto start_loc = i / block_size * batch_item_num; {
auto batch_max = block_reduce<max_block_size>( auto half2_block_size = compute_block_size(batch_item_num, 1024);
idx, max{}, init, batch_item_num, [&](auto j) __device__ { int block_num = batch_shape.elements();
return input[start_loc + j]; int shared_size = batch_item_num * 2 * result.get_shape().type_size();
}); half2_block_size = half2_block_size / 4;
softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
arg.data(), batch_item_num, half2_block_size, result.data());
}
else
{
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
auto start_loc = i / block_size * batch_item_num;
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
return input[start_loc + j];
});
auto batch_sum = block_reduce<max_block_size>( auto batch_sum = block_reduce<max_block_size>(
idx, sum{}, 0, batch_item_num, [&](auto j) __device__ { idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max; auto val = input[start_loc + j] - batch_max;
return ::exp(to_hip_type(val)); return ::exp(to_hip_type(val));
}); });
idx.local_stride(batch_item_num, [&](auto j) __device__ { idx.local_stride(batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max; auto val = input[start_loc + j] - batch_max;
output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum; output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
});
}); });
}); }
} }
else else
{ {
......
...@@ -916,15 +916,15 @@ struct find_gemm_add ...@@ -916,15 +916,15 @@ struct find_gemm_add
auto inputs = gemm_ins->inputs(); auto inputs = gemm_ins->inputs();
inputs.pop_back(); inputs.pop_back();
auto copy_ins = c_ins; // auto copy_ins = c_ins;
// Insert copy // Insert copy
if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty()) // if(ins == p.end() or c_ins->outputs().size() > 1 or c_ins->inputs().empty())
{ // {
copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back()); // copy_ins = p.insert_instruction(ins, hip_copy{}, c_ins, ins->inputs().back());
} // }
inputs.push_back(copy_ins); inputs.push_back(c_ins);
inputs.push_back(copy_ins); inputs.push_back(gemm_ins->inputs().back());
gemm.beta = 1; gemm.beta = 1;
p.replace_instruction(ins, gemm, inputs); p.replace_instruction(ins, gemm, inputs);
......
...@@ -86,6 +86,16 @@ void gemm_impl(context& ctx, ...@@ -86,6 +86,16 @@ void gemm_impl(context& ctx,
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
...@@ -100,128 +110,64 @@ void gemm_impl(context& ctx, ...@@ -100,128 +110,64 @@ void gemm_impl(context& ctx,
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1) if(num_matrices == 1)
{ {
// the rocblas_gemm API handles inputs and output matrices as rocblas_invoke(&rocblas_gemm_ex,
// column-major format. When doing a C = A * B, we actually do ctx.get_stream().get_rocblas(),
// C^T = (B^T) * (A^T). That is the reason we input args[1] as transb ? rocblas_operation_transpose : rocblas_operation_none,
// A and args[0] as B in calling the rocblas_gemm. transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
if(compute_fp32) m,
rocblas_invoke(&rocblas_gemm_ex, k,
ctx.get_stream().get_rocblas(), alpha_v,
transb ? rocblas_operation_transpose : rocblas_operation_none, to_pointer(args.at(1)),
transa ? rocblas_operation_transpose : rocblas_operation_none, arg_type,
n, ldb,
m, to_pointer(args.at(0)),
k, arg_type,
&alpha, lda,
to_pointer(args.at(1)), beta_v,
arg_type, to_pointer(args[2]),
ldb, output_type,
to_pointer(args.at(0)), ldc,
arg_type, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
lda, output_type,
&beta, ldc,
to_pointer(args[2]), compute_type,
output_type, rocblas_gemm_algo_standard,
ldc, 0,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), flag);
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
else
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
else else
{ {
if(compute_fp32) rocblas_invoke(&rocblas_gemm_strided_batched_ex,
rocblas_invoke(&rocblas_gemm_strided_batched_ex, ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(), transb ? rocblas_operation_transpose : rocblas_operation_none,
transb ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, n,
n, m,
m, k,
k, alpha_v,
&alpha, to_pointer(args.at(1)),
to_pointer(args.at(1)), arg_type,
arg_type, ldb,
ldb, k * n,
k * n, to_pointer(args.at(0)),
to_pointer(args.at(0)), arg_type,
arg_type, lda,
lda, m * k,
m * k, beta_v,
&beta, to_pointer(args[2]),
to_pointer(args[2]), output_type,
output_type, ldc,
ldc, m * n,
m * n, is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), output_type,
output_type, ldc,
ldc, m * n,
m * n, num_matrices,
num_matrices, compute_type,
compute_type, rocblas_gemm_algo_standard,
rocblas_gemm_algo_standard, 0,
0, flag);
flag);
else
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
k * n,
to_pointer(args.at(0)),
arg_type,
lda,
m * k,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
m * n,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
m * n,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
}); });
} }
......
...@@ -109,10 +109,9 @@ argument register_on_gpu(const argument& arg) ...@@ -109,10 +109,9 @@ argument register_on_gpu(const argument& arg)
{ {
auto arg_shared = arg.share(); auto arg_shared = arg.share();
auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes())); auto p = share(register_on_gpu(arg_shared.data(), arg_shared.get_shape().bytes()));
return {arg_shared.get_shape(), return {arg_shared.get_shape(), [p, a = std::move(arg_shared)]() mutable {
[ p, a = std::move(arg_shared) ]() mutable {return get_device_ptr(p.get()); return get_device_ptr(p.get());
} }}; // namespace gpu
}; // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
argument to_gpu(const argument& arg, bool host) argument to_gpu(const argument& arg, bool host)
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// Builds the GPU scatternd operation for the given input/output shapes.
// `reduction` selects how an update combines with the existing output
// element: "none" (overwrite), "add", or "mul".
operation
compile_scatternd(context& ctx, const std::vector<shape>& io_shapes, const std::string& reduction);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_SCATTERND_HPP
...@@ -235,6 +235,8 @@ struct context ...@@ -235,6 +235,8 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams); this->current_device = std::make_shared<hip_device>(0, n_streams);
} }
// Returns the current stream's underlying queue handle, type-erased as an
// any_ptr.
any_ptr get_queue() { return get_stream().get(); }
private: private:
// TODO: Make this a vector to support multiple devices // TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device; std::shared_ptr<hip_device> current_device;
......
...@@ -21,6 +21,16 @@ struct greater ...@@ -21,6 +21,16 @@ struct greater
} }
}; };
// Copies the range [first, last) to the range beginning at d_first.
// Returns the output iterator one past the last element written.
// Constexpr, device-friendly equivalent of std::copy.
template <class InputIt, class OutputIt>
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
{
    for(; first != last; ++first, ++d_first)
    {
        *d_first = *first;
    }
    return d_first;
}
template <class Iterator, class Compare> template <class Iterator, class Compare>
constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp) constexpr Iterator is_sorted_until(Iterator first, Iterator last, Compare comp)
{ {
......
...@@ -48,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=) ...@@ -48,7 +48,7 @@ MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(>=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(==)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(!=)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP (^) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(^)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(|)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(&&)
MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||) MIGRAPHX_INTEGRAL_CONSTANT_BINARY_OP(||)
......
...@@ -59,6 +59,7 @@ MIGRAPHX_DEVICE_MATH(cosh, ::cosh) ...@@ -59,6 +59,7 @@ MIGRAPHX_DEVICE_MATH(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH(erf, ::erf) MIGRAPHX_DEVICE_MATH(erf, ::erf)
MIGRAPHX_DEVICE_MATH(exp, ::exp) MIGRAPHX_DEVICE_MATH(exp, ::exp)
MIGRAPHX_DEVICE_MATH(floor, ::floor) MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
...@@ -103,6 +104,7 @@ MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos) ...@@ -103,6 +104,7 @@ MIGRAPHX_DEVICE_MATH_HALF(cos, ::cos)
MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh) MIGRAPHX_DEVICE_MATH_HALF(cosh, ::cosh)
MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
...@@ -129,6 +131,7 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -129,6 +131,7 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
......
#ifndef MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#define MIGRAPHX_GUARD_KERNELS_SCATTERND_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
// Scatternd write policy for reduction="none": overwrite the output
// element with the update value.
struct assign_none
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& out, U val) const
    {
        out = val;
    }
};
// Scatternd write policy for reduction="add": accumulate the update value
// into the output element.
struct assign_add
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& out, U val) const
    {
        out += val;
    }
};
// Scatternd write policy for reduction="mul": multiply the output element
// by the update value.
struct assign_mul
{
    template <class T, class U>
    MIGRAPHX_DEVICE_CONSTEXPR void operator()(T& out, U val) const
    {
        out *= val;
    }
};
// ScatterND device kernel body. For each element i of `updates_t`, the
// leading q-1 coordinates of i select a row of `indices_t`; that row's k
// values become the leading output coordinates, and the remaining update
// coordinates are appended to form the full output index. `f` is one of
// the assign_* policies above and decides how the update combines with
// the existing output value.
// NOTE(review): duplicate indices race here — `f` is a plain read-modify-
// write, not atomic; assumes indices are unique or the race is acceptable.
template <class T, class U, class V, class F>
__device__ void scatternd(const T& indices_t, const U& updates_t, const V& output_t, F f)
{
auto index = make_index();
auto updates_shape = updates_t.get_shape();
index.global_stride(updates_shape.elements(), [&](auto i) {
auto output_shape = output_t.get_shape();
auto indices_shape = indices_t.get_shape();
// k = size of one index tuple; q = rank of the indices tensor.
auto k = indices_shape.lens.back();
auto q = indices_shape.lens.size();
auto updates_idx = updates_shape.multi(i);
auto indices_idx = indices_shape.multi(0);
// First q-1 update coordinates locate the index tuple to read.
copy(updates_idx.begin(), updates_idx.begin() + q - 1, indices_idx.begin());
auto index_start = indices_t.begin() + indices_shape.index(indices_idx);
auto index_end = index_start + k;
auto out_idx = output_shape.multi(0);
// Output index = (k values from indices) ++ (trailing update coords).
copy(index_start, index_end, out_idx.begin());
copy(updates_idx.begin() + q - 1, updates_idx.end(), out_idx.begin() + k);
f(output_t[out_idx], updates_t[i]);
});
}
} // namespace migraphx
#endif
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <migraphx/gpu/abs.hpp> #include <migraphx/gpu/abs.hpp>
#include <migraphx/gpu/batch_norm_inference.hpp> #include <migraphx/gpu/batch_norm_inference.hpp>
#include <migraphx/gpu/compile_roialign.hpp> #include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp> #include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp> #include <migraphx/gpu/deconvolution.hpp>
...@@ -207,6 +208,7 @@ struct miopen_apply ...@@ -207,6 +208,7 @@ struct miopen_apply
add_nms_op(); add_nms_op();
add_quant_convolution_op(); add_quant_convolution_op();
add_roialign(); add_roialign();
add_scatternd();
} }
void copy_params() void copy_params()
...@@ -443,7 +445,6 @@ struct miopen_apply ...@@ -443,7 +445,6 @@ struct miopen_apply
reshapes[2], reshapes[2],
reshapes[3], reshapes[3],
output); output);
}); });
} }
...@@ -503,7 +504,6 @@ struct miopen_apply ...@@ -503,7 +504,6 @@ struct miopen_apply
void add_roialign() void add_roialign()
{ {
apply_map.emplace("roialign", [=](instruction_ref ins) { apply_map.emplace("roialign", [=](instruction_ref ins) {
auto s = ins->get_shape(); auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value(); auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s); auto output = insert_allocation(ins, s);
...@@ -516,6 +516,60 @@ struct miopen_apply ...@@ -516,6 +516,60 @@ struct miopen_apply
}); });
} }
void add_scatternd()
{
apply_map.emplace("scatternd_none", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "none";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
apply_map.emplace("scatternd_add", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "add";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
apply_map.emplace("scatternd_mul", [=](instruction_ref ins) {
auto s = ins->get_shape();
auto op_val = ins->get_operator().to_value();
auto output = insert_allocation(ins, s);
auto args = ins->inputs();
args.push_back(output);
auto io_shapes = to_shapes(args);
io_shapes.erase(io_shapes.begin());
const std::string reduction = "mul";
auto co = compile_scatternd(get_context(), io_shapes, reduction);
auto copy = mod->insert_instruction(ins, make_op("hip::copy"), args.front(), output);
args.back() = copy;
args.erase(args.begin());
return mod->replace_instruction(ins, co, args);
});
}
// replace the loop operator with gpu_loop operator // replace the loop operator with gpu_loop operator
void add_loop_op() void add_loop_op()
{ {
......
...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads) ...@@ -15,8 +15,6 @@ target_link_libraries(migraphx_ref migraphx Threads::Threads)
target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE}) target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS) target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref)
rocm_install_targets( rocm_install_targets(
TARGETS migraphx_ref TARGETS migraphx_ref
INCLUDE INCLUDE
......
...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w) ...@@ -19,7 +19,7 @@ target_compile_options(tf-proto PRIVATE -w)
target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY}) target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On) set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
file(GLOB TF_SRCS *.cpp) file(GLOB TF_SRCS ${CONFIGURE_DEPENDS} *.cpp)
add_library(migraphx_tf ${TF_SRCS}) add_library(migraphx_tf ${TF_SRCS})
target_include_directories(migraphx_tf PRIVATE include) target_include_directories(migraphx_tf PRIVATE include)
set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf) set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
......
...@@ -499,8 +499,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const ...@@ -499,8 +499,7 @@ literal tf_parser::parse_tensor(const tensorflow::TensorProto& t) const
return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); return create_literal(shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF: case tensorflow::DataType::DT_HALF: {
{
std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size); std::vector<int> data_int32 = get_data_vals(t.half_val(), shape_size);
std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end()); std::vector<uint16_t> data_uint16(data_int32.begin(), data_int32.end());
std::vector<half> data_half; std::vector<half> data_half;
......
...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME) ...@@ -90,7 +90,7 @@ function(add_test_executable TEST_NAME)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
file(GLOB TESTS *.cpp) file(GLOB TESTS ${CONFIGURE_DEPENDS} *.cpp)
foreach(TEST ${TESTS}) foreach(TEST ${TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -100,7 +100,7 @@ endforeach() ...@@ -100,7 +100,7 @@ endforeach()
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
# gpu tests # gpu tests
file(GLOB GPU_TESTS gpu/*.cpp) file(GLOB GPU_TESTS ${CONFIGURE_DEPENDS} gpu/*.cpp)
foreach(TEST ${GPU_TESTS}) foreach(TEST ${GPU_TESTS})
get_filename_component(BASE_NAME ${TEST} NAME_WE) get_filename_component(BASE_NAME ${TEST} NAME_WE)
...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp) ...@@ -120,7 +120,7 @@ file (GLOB ONNX_TESTS ${TEST_ONNX_DIR}/*.cpp)
foreach(ONNX_TEST ${ONNX_TESTS}) foreach(ONNX_TEST ${ONNX_TESTS})
get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE) get_filename_component(BASE_NAME ${ONNX_TEST} NAME_WE)
set(TEST_NAME test_${BASE_NAME}) set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${TES_ONNX_DIR}/${ONNX_TEST}) add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME}) rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref) target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
...@@ -160,7 +160,7 @@ function(test_header NAME HEADER) ...@@ -160,7 +160,7 @@ function(test_header NAME HEADER)
endfunction() endfunction()
function(test_headers PREFIX) function(test_headers PREFIX)
file(GLOB HEADERS ${ARGN}) file(GLOB HEADERS ${CONFIGURE_DEPENDS} ${ARGN})
foreach(HEADER ${HEADERS}) foreach(HEADER ${HEADERS})
file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
......
#include <migraphx/any_ptr.hpp>
#include <test.hpp>
// Checks any_ptr constructed directly from a typed pointer: retrieval
// succeeds with the exact pointer type or the registered type name, and
// any mismatched type or name throws.
TEST_CASE(test_int_id)
{
int i = 1;
migraphx::any_ptr p = &i;
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
// Wrong pointer type must throw.
EXPECT(test::throws([&] { p.get<float*>(); }));
// Name of int* (not int) must also be rejected.
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
}
// Checks any_ptr constructed from a void* plus an explicit type name:
// same success/failure behavior as the typed constructor.
TEST_CASE(test_int_name)
{
int i = 1;
void* vp = &i;
migraphx::any_ptr p{vp, migraphx::get_type_name(i)};
EXPECT(p.get<int*>() == &i);
EXPECT(p.get(migraphx::get_type_name(i)) == &i);
EXPECT(p.unsafe_get() == &i);
EXPECT(test::throws([&] { p.get<float*>(); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(&i)); }));
EXPECT(test::throws([&] { p.get(migraphx::get_type_name(float{})); }));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -12,6 +12,7 @@ endfunction() ...@@ -12,6 +12,7 @@ endfunction()
add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR}) add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR}) add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR}) add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR}) add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR}) add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
......
#include <migraphx/migraphx.hpp>
#include <migraphx/rank.hpp>
#include "test.hpp"
// Fallback overload: selected (via rank<0>) when as_handle<T> is not a
// valid type, i.e. T is not a migraphx C-API handle.
template <class T>
std::false_type has_handle(migraphx::rank<0>, T)
{
return {};
}
// Preferred overload: selected (via rank<1>) only when as_handle<T> is
// well-formed, i.e. T* is a migraphx C-API handle pointer.
template <class T>
auto has_handle(migraphx::rank<1>, T*) -> decltype(migraphx::as_handle<T>{}, std::true_type{})
{
return {};
}
// as_handle must map the raw struct, its pointer alias, and the const
// pointer alias all to migraphx::shape.
TEST_CASE(shape)
{
static_assert(std::is_same<migraphx::as_handle<migraphx_shape>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<migraphx_shape_t>, migraphx::shape>{}, "Failed");
static_assert(std::is_same<migraphx::as_handle<const_migraphx_shape_t>, migraphx::shape>{},
"Failed");
}
// A handle type is detected; an arbitrary pointer (int*) is not.
TEST_CASE(non_handle)
{
int i = 0;
EXPECT(bool{has_handle(migraphx::rank<1>{}, migraphx_shape_t{})});
EXPECT(bool{not has_handle(migraphx::rank<1>{}, &i)});
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment