Commit 7dc6e3ae authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into mi100_opts

parents f94d77fc a275f590
......@@ -392,8 +392,10 @@ struct cpu_apply
extend_op("concat", "dnnl::concat");
extend_op("contiguous", "dnnl::reorder");
extend_op("convolution", "dnnl::convolution");
#ifndef MIGRAPHX_ENABLE_ZENDNN
extend_op("deconvolution", "dnnl::deconvolution");
extend_op("dot", "dnnl::dot");
#endif
extend_op("erf", "cpu::erf");
extend_op("gather", "cpu::gather");
extend_op("logsoftmax", "dnnl::logsoftmax");
......
......@@ -12,7 +12,7 @@ struct dnnl_lrn : dnnl_extend_op<dnnl_lrn, dnnl::lrn_forward, op::lrn>
{
return {dnnl::prop_kind::forward_inference,
dnnl::algorithm::lrn_across_channels,
m.at(DNNL_ARG_SRC_0),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)),
this->op.size,
this->op.alpha,
this->op.beta,
......
......@@ -125,7 +125,7 @@ template struct cpu_pooling<max_pool>;
struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::pooling>
{
std::vector<int> arg_map(int) const { return {DNNL_ARG_SRC}; }
std::vector<int> arg_map(int) const { return {MIGRAPHX_DNNL_PREFIX(ARG_SRC)}; }
dnnl::pooling_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
......@@ -135,8 +135,8 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
std::vector<size_t> padding_r(op.padding.begin() + kdims, op.padding.end());
return {dnnl::prop_kind::forward_inference,
algo,
m.at(DNNL_ARG_SRC),
m.at(DNNL_ARG_DST),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
to_dnnl_dims(op.stride),
to_dnnl_dims(op.lengths),
to_dnnl_dims(padding_l),
......
File mode changed from 100755 to 100644
......@@ -37,7 +37,11 @@ struct dnnl_reduction : dnnl_op<dnnl_reduction, dnnl::reduction>
dnnl::reduction::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {to_dnnl_algo(algo), m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST), 0, 0};
return {to_dnnl_algo(algo),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)),
m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST)),
0,
0};
}
};
......
......@@ -27,7 +27,7 @@ struct dnnl_reorder : dnnl_op<dnnl_reorder, dnnl::reorder>
};
desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
return {m.at(DNNL_ARG_SRC), m.at(DNNL_ARG_DST)};
return {m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC)), m.at(MIGRAPHX_DNNL_PREFIX(ARG_DST))};
}
auto get_primitive_desc(const desc& d, const dnnl::primitive_attr& attr) const
......
......@@ -11,7 +11,7 @@ struct dnnl_softmax : dnnl_extend_op<dnnl_softmax, dnnl::softmax_forward, op::so
dnnl::softmax_forward::desc get_desc(const std::unordered_map<int, dnnl::memory::desc>& m) const
{
int axis = this->op.axis;
return {dnnl::prop_kind::forward_inference, m.at(DNNL_ARG_SRC_0), axis};
return {dnnl::prop_kind::forward_inference, m.at(MIGRAPHX_DNNL_PREFIX(ARG_SRC_0)), axis};
}
};
......
......@@ -22,6 +22,7 @@
#include <migraphx/schedule.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_qdq.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
......
......@@ -41,6 +41,7 @@ add_library(migraphx_device
device/equal.cpp
device/erf.cpp
device/exp.cpp
device/fill.cpp
device/floor.cpp
device/gather.cpp
device/gelu.cpp
......@@ -84,7 +85,9 @@ add_library(migraphx_device
device/sub.cpp
device/tan.cpp
device/tanh.cpp
device/topk.cpp
device/unary_not.cpp
device/where.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
......@@ -119,6 +122,7 @@ add_library(migraphx_gpu
code_object_op.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_pointwise.cpp
concat.cpp
convert.cpp
convolution.cpp
......@@ -135,6 +139,7 @@ add_library(migraphx_gpu
kernel.cpp
lowering.cpp
logsoftmax.cpp
loop.cpp
lrn.cpp
leaky_relu.cpp
mlir_conv.cpp
......@@ -151,6 +156,7 @@ add_library(migraphx_gpu
softmax.cpp
sync_device.cpp
target.cpp
topk.cpp
write_literals.cpp
)
set_target_properties(migraphx_gpu PROPERTIES EXPORT_NAME gpu)
......@@ -189,6 +195,7 @@ register_migraphx_gpu_ops(hip_
logical_and
logical_or
logical_xor
loop
max
min
mul
......@@ -217,7 +224,9 @@ register_migraphx_gpu_ops(hip_
sub
tanh
tan
topk
unary_not
where
)
register_migraphx_gpu_ops(miopen_
abs
......@@ -283,19 +292,27 @@ if(MIGRAPHX_ENABLE_MLIR)
target_link_libraries(migraphx_gpu PUBLIC ${LIBMLIRMIOPEN})
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "")
if(MIGRAPHX_USE_HIPRTC)
target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
else()
# Get flags needed to compile hip
include(TargetFlags)
target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE "--cuda-gpu-arch=[^ \t\r\n]+" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE "--offload-arch=[^ \t\r\n]+" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REPLACE "$<LINK_LANGUAGE:CXX>" "1" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REPLACE "SHELL:" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
"-DMIGRAPHX_HIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}"
"-DMIGRAPHX_OFFLOADBUNDLER_BIN=${MIGRAPHX_OFFLOADBUNDLER_BIN}"
"-DMIGRAPHX_EXTRACT_KERNEL=${MIGRAPHX_EXTRACT_KERNEL}"
"-DMIGRAPHX_USE_HIPRTC=0"
)
endif()
# Check miopen find mode api
include(CheckLibraryExists)
......@@ -313,6 +330,8 @@ target_compile_definitions(migraphx_gpu PUBLIC -D__HIP_PLATFORM_HCC__=1)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory(driver)
rocm_install_targets(
TARGETS migraphx_gpu migraphx_device
INCLUDE
......
File mode changed from 100644 to 100755
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/env.hpp>
#include <cassert>
#include <iostream>
#if MIGRAPHX_USE_HIPRTC
#include <hip/hiprtc.h>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/env.hpp>
#else
#include <migraphx/compile_src.hpp>
#include <migraphx/process.hpp>
#include <cassert>
#endif
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_OPTIMIZE);
#if MIGRAPHX_USE_HIPRTC
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_HIPRTC)
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
{
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
}
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
{
if(err != HIPRTC_SUCCESS)
throw make_exception(ctx, hiprtc_error(err, msg));
}
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, MIGRAPHX_MAKE_SOURCE_CTX())
#define MIGRAPHX_HIPRTC_THROW(error, msg) MIGRAPHX_THROW(hiprtc_error(error, msg))
// Workaround hiprtc's broken API
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
template <class... Ts>
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
{
hiprtcProgram prog = nullptr;
auto result = hiprtcCreateProgram(&prog, xs...);
hiprtc_program_ptr p{prog};
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
return p;
}
struct hiprtc_program
{
struct string_array
{
std::vector<std::string> strings{};
std::vector<const char*> c_strs{};
string_array() {}
string_array(const string_array&) = delete;
std::size_t size() const { return strings.size(); }
const char** data() { return c_strs.data(); }
void push_back(std::string s)
{
strings.push_back(std::move(s));
c_strs.push_back(strings.back().c_str());
}
};
hiprtc_program_ptr prog = nullptr;
string_array headers{};
string_array include_names{};
std::string cpp_src = "";
std::string cpp_name = "";
hiprtc_program(const std::vector<src_file>& srcs)
{
for(auto&& src : srcs)
{
std::string content{src.content.first, src.content.second};
std::string path = src.path.string();
if(src.path.extension().string() == ".cpp")
{
cpp_src = std::move(content);
cpp_name = std::move(path);
}
else
{
headers.push_back(std::move(content));
include_names.push_back(std::move(path));
}
}
prog = hiprtc_program_create(cpp_src.c_str(),
cpp_name.c_str(),
headers.size(),
headers.data(),
include_names.data());
}
void compile(const std::vector<std::string>& options)
{
if(enabled(MIGRAPHX_TRACE_HIPRTC{}))
std::cout << "hiprtc " << join_strings(options, " ") << " " << cpp_name << std::endl;
std::vector<const char*> c_options;
std::transform(options.begin(),
options.end(),
std::back_inserter(c_options),
[](const std::string& s) { return s.c_str(); });
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
std::cerr << log() << std::endl;
if(result != HIPRTC_SUCCESS)
MIGRAPHX_HIPRTC_THROW(result, "Compilation failed.");
}
std::string log()
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
if(n < 2)
return {};
std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
assert(buffer.back() == 0);
return {buffer.begin(), buffer.end() - 1};
}
std::vector<char> get_code_obj()
{
std::size_t n = 0;
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
std::vector<char> buffer(n);
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
return buffer;
}
};
std::vector<std::vector<char>>
compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std::string& arch)
{
hiprtc_program prog(srcs);
auto options = split_string(params, ' ');
if(enabled(MIGRAPHX_GPU_DEBUG{}))
options.push_back("-DMIGRAPHX_DEBUG");
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
return starts_with(s, "--std=") or starts_with(s, "-std=");
}))
options.push_back("-std=c++17");
options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat");
options.push_back("--cuda-gpu-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()};
}
#else // MIGRAPHX_USE_HIPRTC
bool is_hcc_compiler()
{
static const auto result = ends_with(MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER), "hcc");
......@@ -41,9 +197,12 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
params += " --cuda-gpu-arch=" + arch;
params += " --cuda-device-only";
params += " -O3 ";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
if(enabled(MIGRAPHX_GPU_DEBUG{}))
params += " -DMIGRAPHX_DEBUG";
params += " -Wno-unused-command-line-argument -Wno-cuda-compat ";
params += MIGRAPHX_STRINGIZE(MIGRAPHX_HIP_COMPILER_FLAGS);
......@@ -71,6 +230,8 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
return {compiler.compile(srcs)};
}
#endif // MIGRAPHX_USE_HIPRTC
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -68,6 +68,31 @@ __content__
return replace_string(args_hpp, "__content__", inner);
}
const std::vector<std::string>& compiler_warnings()
{
static std::vector<std::string> warnings = {"-Weverything",
"-Wno-c++98-compat",
"-Wno-c++98-compat-pedantic",
"-Wno-conversion",
"-Wno-double-promotion",
"-Wno-exit-time-destructors",
"-Wno-extra-semi",
"-Wno-extra-semi-stmt",
"-Wno-float-conversion",
"-Wno-gnu-anonymous-struct",
"-Wno-gnu-zero-variadic-macro-arguments",
"-Wno-missing-prototypes",
"-Wno-nested-anon-types",
"-Wno-padded",
"-Wno-shorten-64-to-32",
"-Wno-sign-conversion",
"-Wno-sign-compare",
"-Wno-unused-command-line-argument",
"-Wno-weak-vtables",
"-Wno-c99-extensions"};
return warnings;
}
operation compile_hip_code_object(const std::string& content, hip_compile_options options)
{
std::vector<src_file> srcs;
......@@ -82,10 +107,14 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
auto args_hpp = generate_args_hpp(options.inputs);
auto args_hpp =
generate_args_hpp(options.reduced_inputs.empty() ? options.inputs : options.reduced_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"},
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
options.params += " -I.";
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
......
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
using namespace migraphx;
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
int main() {}
)__migraphx__";
std::string enum_params(std::size_t count, std::string param)
{
std::vector<std::string> items(count);
transform(range(count), items.begin(), [&](auto i) { return param + std::to_string(i); });
return join_strings(items, ",");
}
std::size_t compute_global(std::size_t n, std::size_t local = 1024)
{
std::size_t groups = (n + local - 1) / local;
std::size_t nglobal = std::min<std::size_t>(256, groups) * local;
return nglobal;
}
operation compile_pointwise(context&, const std::vector<shape>& inputs, const std::string& lambda)
{
hip_compile_options options;
options.global = compute_global(inputs.front().elements());
options.local = 1024;
options.inputs = inputs;
options.output = inputs.back();
options.reduced_inputs = reduce_dims(inputs);
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", lambda}});
return compile_hip_code_object(src, options);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/fill.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void fill(hipStream_t stream, const argument& result, unsigned long val)
{
nary(stream, result)([=]() __device__ { return val; });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -352,7 +352,8 @@ bool broadcastable(bool& divisible_by_4,
auto b_len = result.get_shape().lens()[b_idx];
auto b_stride = result.get_shape().strides()[b_idx];
assert(bshape.lens()[b_idx] == b_len);
if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero) and
is_divisor_encodable(b_stride * b_len))
{
divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
......
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/topk.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class T, class Index, class Compare>
struct hip_heap_vector
{
MIGRAPHX_DEVICE_CONSTEXPR hip_heap_vector(T* val, index_int n, Index v_idx, Compare comp)
: data(val), size(n), data_index(v_idx), compare(comp)
{
make_heap(size);
}
MIGRAPHX_DEVICE_CONSTEXPR void try_push(const T val)
{
if(compare(val, data[data_index(0)]))
return;
pop_heap(size - 1);
data[data_index(size - 1)] = val;
push_heap(size - 1);
}
MIGRAPHX_DEVICE_CONSTEXPR void sort() { sort_heap(size); }
private:
MIGRAPHX_DEVICE_CONSTEXPR inline static void swap(T& v1, T& v2)
{
T v = v1;
v1 = v2;
v2 = v;
}
MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_down(index_int n, index_int index)
{
while(index < n)
{
auto pre_index = index;
index_int l = 2 * index + 1;
index_int r = 2 * index + 2;
if(l < n && compare(data[data_index(l)], data[data_index(index)]))
{
index = l;
}
if(r < n && compare(data[data_index(r)], data[data_index(index)]))
{
index = r;
if(compare(data[data_index(l)], data[data_index(r)]))
{
index = l;
}
}
if(index == pre_index)
{
break;
}
swap(data[data_index(index)], data[data_index(pre_index)]);
}
}
MIGRAPHX_DEVICE_CONSTEXPR inline void heapify_up(index_int index)
{
while(index > 0)
{
auto parent_idx = (index - 1) / 2;
if(not compare(data[data_index(index)], data[data_index(parent_idx)]))
{
break;
}
swap(data[data_index(index)], data[data_index(parent_idx)]);
index = parent_idx;
}
}
MIGRAPHX_DEVICE_CONSTEXPR inline void make_heap(index_int n)
{
for(int j = n / 2 - 1; j >= 0; --j)
{
heapify_down(n, j);
}
}
MIGRAPHX_DEVICE_CONSTEXPR inline void push_heap(index_int loc) { heapify_up(loc); }
MIGRAPHX_DEVICE_CONSTEXPR inline void pop_heap(index_int loc)
{
swap(data[data_index(0)], data[data_index(loc)]);
heapify_down(loc, 0);
}
MIGRAPHX_DEVICE_CONSTEXPR inline void sort_heap(index_int n)
{
for(int j = n - 1; j > 0; --j)
{
swap(data[data_index(0)], data[data_index(j)]);
heapify_down(j, 0);
}
}
T* data = nullptr;
index_int size;
Index data_index;
Compare compare;
};
template <class T, class Index, class Compare>
__device__ hip_heap_vector<T, Index, Compare>
make_heap(T* data, index_int n, Index idx, Compare compare)
{
return {data, n, idx, compare};
}
template <class Compare>
std::vector<argument> topk(hipStream_t stream,
const argument& val_res,
const argument& ind_res,
const argument& arg,
int64_t k,
int64_t axis,
Compare compare)
{
auto in_s = arg.get_shape();
auto in_lens = in_s.lens();
auto out_s = val_res.get_shape();
auto axis_dim = in_s.lens()[axis];
auto comp_lens = in_lens;
comp_lens[axis] = 1;
shape comp_s{in_s.type(), comp_lens};
std::size_t elem_num = comp_s.elements();
hip_visit_all(val_res, arg, out_s, in_s, comp_s)(
[&](auto out_val, auto input, auto oss, auto iss, auto css) {
auto* data = device_cast(input.data());
auto* out = device_cast(out_val.data());
auto* const ind = ind_res.cast<int64_t>();
gs_launch(stream, elem_num)([=](auto i) __device__ {
auto idx = css.multi(i);
auto in_idx = [&](int ii) {
auto iidx = idx;
iidx[axis] = ii;
return iss.index(iidx);
};
auto out_idx = [&](int ii) {
auto iidx = idx;
iidx[axis] = ii;
return oss.index(iidx);
};
auto data_compare = [=](auto ii, auto jj) {
return compare(data[in_idx(ii)], data[in_idx(jj)]);
};
for(int j = 0; j < k; ++j)
{
ind[out_idx(j)] = j;
}
auto hp = make_heap(ind, k, out_idx, data_compare);
for(int j = k; j < axis_dim; ++j)
{
hp.try_push(j);
}
hp.sort();
for(int j = 0; j < k; ++j)
{
out[out_idx(j)] = data[in_idx(ind[out_idx(j)])];
}
});
});
return {val_res, ind_res};
}
argument topk_largest(hipStream_t stream,
const argument& val_res,
const argument& ind_res,
const argument& arg,
int64_t k,
int64_t axis)
{
return {topk(stream, val_res, ind_res, arg, k, axis, std::less<>{})};
}
argument topk_smallest(hipStream_t stream,
const argument& val_res,
const argument& ind_res,
const argument& arg,
int64_t k,
int64_t axis)
{
return {topk(stream, val_res, ind_res, arg, k, axis, std::greater<>{})};
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device/where.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/launch.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
template <class Shape>
constexpr auto get_rank(const Shape&)
{
return decltype(typename Shape::hip_index{}.size()){};
}
void where(hipStream_t stream,
const argument& result,
const argument& arg0,
const argument& arg1,
const argument& arg2)
{
hip_visit_all(result, arg1, arg2)([&](auto output, auto x, auto y) {
hip_visit_all(arg0)([&](auto cond) {
if constexpr(get_rank(cond.get_shape()) == get_rank(output.get_shape()))
{
gs_launch(stream, arg1.get_shape().elements())([=](auto idx) __device__ {
auto i = output.get_shape().multi(idx);
output[i] = cond[i] ? x[i] : y[i];
});
}
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
add_executable(gpu-driver
action.cpp
compile_pointwise.cpp
main.cpp
parser.cpp
perf.cpp
run_op.cpp
)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/errors.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
auto& action_map()
{
static std::unordered_map<std::string, action_function> m;
return m;
}
action_function get_action(const std::string& name)
{
if(action_map().count(name) == 0)
MIGRAPHX_THROW("Missing action: " + name);
return action_map().at(name);
}
void register_action(const std::string& name, const action_function& a) { action_map()[name] = a; }
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/driver/perf.hpp>
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace driver {
struct compile_pointwise : action<compile_pointwise>
{
static void apply(const parser& p, const value& v)
{
context ctx;
auto inputs = p.parse_shapes(v.at("inputs"));
auto op = gpu::compile_pointwise(ctx, inputs, v.at("lambda").to<std::string>());
double t = time_op(ctx, op, inputs, p.get(v, "iterations", 100));
std::cout << op << ": " << t << "ms" << std::endl;
}
};
} // namespace driver
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment