Commit 15bf3d62 authored by Paul

Merge branch 'master' into memory_coloring

parents 7fa4d978 d2778c9e
@@ -3,6 +3,7 @@ add_library(migraph
    auto_contiguous.cpp
    dead_code_elimination.cpp
    eliminate_contiguous.cpp
+   fwd_conv_batchnorm_rewrite.cpp
    env.cpp
    generate.cpp
    program.cpp
......
#include <algorithm>
#include <cmath>

#include <migraph/fwd_conv_batchnorm_rewrite.hpp>
#include <migraph/program.hpp>
#include <migraph/instruction.hpp>
#include <migraph/operators.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/dfor.hpp>
namespace migraph {
void fwd_conv_batchnorm_rewrite::apply(program& p) const
{
for(auto ins : iterator_for(p))
{
if(ins->op.name() != "batch_norm_inference")
continue;
if(not std::all_of(ins->arguments.begin() + 1, ins->arguments.end(), [](auto arg) {
return arg->op.name() == "@literal";
}))
continue;
auto conv_ins = ins->arguments[0];
if(conv_ins->op.name() != "convolution")
continue;
if(conv_ins->arguments[1]->op.name() != "@literal")
continue;
// Get scale, bias, mean, variance from instruction_ref
const auto& gamma = ins->arguments[1]->get_literal();
const auto& bias = ins->arguments[2]->get_literal();
const auto& mean = ins->arguments[3]->get_literal();
const auto& variance = ins->arguments[4]->get_literal();
// Get epsilon
auto bn_op = any_cast<batch_norm_inference>(ins->op);
auto epsilon = bn_op.epsilon;
// Get convolution weights
const auto& weights = conv_ins->arguments[1]->get_literal();
// Get convolution op
auto conv_op = conv_ins->op;
auto weights_lens = weights.get_shape().lens();
auto conv_lens = conv_ins->get_shape().lens();
argument new_weights{weights.get_shape()};
argument new_bias{bias.get_shape()};
visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
[&](auto weights2,
auto gamma2,
auto bias2,
auto mean2,
auto variance2,
auto new_weights2,
auto new_bias2) {
dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
[&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
new_weights2(k, c, h, w) =
gamma2(k) / std::sqrt(variance2(k) + epsilon) * weights2(k, c, h, w);
});
            dfor(new_bias.get_shape().elements())([&](std::size_t c) {
                // Shift term of the fold: b' = beta - gamma * mu / sqrt(var + eps)
                new_bias2(c) =
                    bias2(c) - gamma2(c) * mean2(c) / std::sqrt(variance2(c) + epsilon);
            });
});
// Replace convolution instruction with updated weights
auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
auto l_bias = p.add_literal({new_bias.get_shape(), new_bias.data()});
auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->arguments[0], l_weights});
auto b = p.insert_instruction(ins, broadcast{1}, c, l_bias);
p.replace_instruction(ins, add{}, {c, b});
}
}
} // namespace migraph
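For reference, the algebra this pass folds, per output channel k (gamma, bias, mean, and variance are the literals read above):

$$\gamma_k\,\frac{\mathrm{conv}(x,W)_k-\mu_k}{\sqrt{\sigma_k^2+\varepsilon}}+\beta_k \;=\; \mathrm{conv}(x,W')_k+b'_k,\qquad W'_k=\frac{\gamma_k}{\sqrt{\sigma_k^2+\varepsilon}}\,W_k,\qquad b'_k=\beta_k-\frac{\gamma_k\,\mu_k}{\sqrt{\sigma_k^2+\varepsilon}}$$

so the shift term carries the same per-channel scale as the weights, which is what the bias loop above computes.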
@@ -5,6 +5,14 @@
namespace migraph {

+struct swallow
+{
+    template <class... Ts>
+    constexpr swallow(Ts&&...)
+    {
+    }
+};

namespace detail {

template <class R, class F>
@@ -19,8 +27,48 @@ struct fix_f
    }
};

+template <std::size_t...>
+struct seq
+{
+    using type = seq;
+};
+
+template <class, class>
+struct merge_seq;
+
+template <std::size_t... Xs, std::size_t... Ys>
+struct merge_seq<seq<Xs...>, seq<Ys...>> : seq<Xs..., (sizeof...(Xs) + Ys)...>
+{
+};
+
+template <std::size_t N>
+struct gens : merge_seq<typename gens<N / 2>::type, typename gens<N - N / 2>::type>
+{
+};
+
+template <>
+struct gens<0> : seq<>
+{
+};
+
+template <>
+struct gens<1> : seq<0>
+{
+};
+
+template <class F, std::size_t... Ns>
+constexpr void repeat_c_impl(F f, seq<Ns...>)
+{
+    swallow{(f(std::integral_constant<std::size_t, Ns>{}), 0)...};
+}
+
} // namespace detail

+template <std::size_t N, class F>
+constexpr void repeat_c(F f)
+{
+    detail::repeat_c_impl(f, detail::gens<N>{});
+}
+
/// Implements a fix-point combinator
template <class R, class F>
detail::fix_f<R, F> fix(F f)
@@ -35,7 +83,7 @@ auto fix(F f)
}

template <class... Ts>
-auto make_sequence(Ts... xs)
+auto pack(Ts... xs)
{
    return [=](auto f) { return f(xs...); };
}
......
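The additions to functional.hpp implement a compile-time repeat: gens<N> builds seq<0, ..., N-1> in O(log N) template depth by merging two half-length sequences, and swallow{(f(ic), 0)...} forces the pack expansion that invokes the callback once per index. A minimal usage sketch (assuming the header is self-contained; this test program is not part of the commit):

```cpp
#include <iostream>
#include <migraph/functional.hpp>

int main()
{
    // f receives std::integral_constant<std::size_t, I> for I = 0..3, so each
    // index is usable as a compile-time constant inside the lambda.
    migraph::repeat_c<4>([](auto i) { std::cout << decltype(i)::value << " "; });
    std::cout << std::endl; // prints: 0 1 2 3
}
```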
#ifndef MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#define MIGRAPH_GUARD_RTGLIB_FWD_CONV_BATCHNORM_REWRITE_HPP
#include <string>
#include <migraph/instruction_ref.hpp>
namespace migraph {
struct program;
struct fwd_conv_batchnorm_rewrite
{
std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
void apply(program& p) const;
};
} // namespace migraph
#endif
@@ -12,7 +12,11 @@ constexpr T normalize(unsigned long z)
{
    if(z == 0)
        return 0;
-    return (2.0 / z) - 1.0;
+    const auto max = 32768;
+    const double range = max / 2; // NOLINT
+    double result = (z % max) / range;
+    result -= 1;
+    return result;
}

template <class T, MIGRAPH_REQUIRES(std::is_signed<T>{} and not std::is_floating_point<T>{})>
@@ -54,11 +58,29 @@ struct xorshf96_generator
    }
};

+template <class T>
+struct xorshift_generator
+{
+    unsigned long x;
+
+    xorshift_generator(unsigned long seed = 0) : x(521288629ULL ^ seed) {}
+
+    constexpr T operator()() noexcept
+    {
+        x ^= x >> 12U;
+        x ^= x << 25U;
+        x ^= x >> 27U;
+        return normalize<T>(x * 0x2545F4914F6CDD1D);
+    }
+};
+
template <class T>
std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed = 0)
{
    std::vector<T> result(s.elements());
    std::generate(result.begin(), result.end(), xorshf96_generator<T>{seed});
+    // std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
+    // std::generate(result.begin(), result.end(), []{ return 1; });
    return result;
}
......
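The reworked normalize maps any z into [-1, 1) as (z mod 32768) / 16384 - 1, instead of collapsing large z toward -1 the way (2.0 / z) - 1.0 did. A host-side sanity check of that mapping (a sketch mirroring the floating-point branch, not code from this commit):

```cpp
#include <cassert>

// Mirrors the floating-point branch of normalize above.
double normalize_check(unsigned long z)
{
    if(z == 0)
        return 0;
    return (z % 32768) / 16384.0 - 1.0;
}

int main()
{
    assert(normalize_check(0) == 0.0);
    assert(normalize_check(16384) == 0.0);                // midpoint of the period
    assert(normalize_check(32767) < 1.0);                 // strictly below 1
    assert(normalize_check(32769) == -1.0 + 1 / 16384.0); // wraps around
}
```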
@@ -115,6 +115,11 @@ struct instruction
    }

    shape get_shape() const { return result; }
+    const literal& get_literal() const
+    {
+        assert(op.name() == "@literal");
+        return lit;
+    }

    friend bool operator==(instruction_ref ref, const instruction& i) { return i == ref; }
......
@@ -2,17 +2,10 @@
#define MIGRAPH_GUARD_RTGLIB_TRACER_HPP

#include <ostream>
+#include <migraph/functional.hpp>

namespace migraph {

-struct swallow
-{
-    template <class... Ts>
-    swallow(Ts&&...)
-    {
-    }
-};

struct tracer
{
    tracer() {}
......
#ifndef MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP
#define MIGRAPH_GUARD_RTGLIB_VERIFY_ARGS_HPP

#include <iostream>
#include <string>

#include <migraph/verify.hpp>
#include <migraph/argument.hpp>
namespace migraph {
inline void verify_args(const std::string& name,
const argument& cpu_arg,
const argument& gpu_arg,
double tolerance = 80)
{
visit_all(cpu_arg, gpu_arg)([&](auto cpu, auto gpu) {
if(not verify_range(cpu, gpu, tolerance))
{
// TODO: Check for nans
std::cout << "FAILED: " << name << std::endl;
if(cpu.size() < 32)
std::cout << "cpu:" << cpu << std::endl;
if(gpu.size() < 32)
std::cout << "gpu:" << gpu << std::endl;
if(range_zero(cpu))
std::cout << "Cpu data is all zeros" << std::endl;
if(range_zero(gpu))
std::cout << "Gpu data is all zeros" << std::endl;
auto idx = mismatch_idx(cpu, gpu, float_equal);
if(idx < range_distance(cpu))
{
std::cout << "Mismatch at " << idx << ": " << cpu[idx] << " != " << gpu[idx]
<< std::endl;
}
auto cpu_nan_idx = find_idx(cpu, not_finite);
if(cpu_nan_idx >= 0)
std::cout << "Non finite number found in cpu at " << cpu_nan_idx << ": "
<< cpu[cpu_nan_idx] << std::endl;
auto gpu_nan_idx = find_idx(gpu, not_finite);
if(gpu_nan_idx >= 0)
std::cout << "Non finite number found in gpu at " << gpu_nan_idx << ": "
<< gpu[gpu_nan_idx] << std::endl;
std::cout << std::endl;
}
});
}
} // namespace migraph
#endif
@@ -17,11 +17,15 @@ rocm_clang_tidy_check(read_onnx)
target_link_libraries(read_onnx migraph_onnx)

-if(MIGRAPH_ENABLE_GPU)
add_executable(mnist mnist.cpp)
rocm_clang_tidy_check(mnist)
-target_link_libraries(mnist migraph_cpu migraph_onnx)
+target_link_libraries(mnist migraph_cpu migraph_gpu migraph_onnx)

+add_executable(cifar10 cifar10.cpp)
+rocm_clang_tidy_check(cifar10)
+target_link_libraries(cifar10 migraph_cpu migraph_gpu migraph_onnx)

+if(MIGRAPH_ENABLE_GPU)
add_executable(verify_onnx verify_onnx.cpp)
rocm_clang_tidy_check(verify_onnx)
target_link_libraries(verify_onnx migraph_onnx migraph_cpu migraph_gpu)
......
#include <cstdio>
#include <cstdint>
#include <string>
#include <fstream>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

#include <migraph/onnx.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>

#include "softmax.hpp"
auto read_cifar10_images(const std::string& full_path)
{
std::ifstream file(full_path, std::ios::binary);
const size_t nimages = 10;
const size_t nbytes_per_image = 3072;
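    // CIFAR-10 binary layout: each record is 1 label byte followed by 3072 pixel
    // bytes (3 channels x 32 x 32, channel-major: all R, then all G, then all B).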
std::vector<uint8_t> raw_data(nimages * (nbytes_per_image + 1));
std::vector<uint8_t> labels(nimages);
std::vector<float> data(nimages * nbytes_per_image);
if(file.is_open())
{
file.read(reinterpret_cast<char*>(raw_data.data()),
(nbytes_per_image + 1) * nimages * sizeof(uint8_t));
uint8_t* pimage = raw_data.data();
for(size_t i = 0; i < nimages; i++, pimage += nbytes_per_image)
{
labels[i] = *pimage++;
for(size_t j = 0; j < nbytes_per_image; j++)
{
float v = *(pimage + j) / 255.0f;
data[i * nbytes_per_image + j] = v;
}
}
return std::make_pair(labels, data);
}
else
{
throw std::runtime_error("Cannot open file `" + full_path + "`!");
}
}
int main(int argc, char const* argv[])
{
if(argc < 4)
{
throw std::runtime_error("Usage: cifar10 [gpu | cpu] <onnx file> <cifar10 data file>");
}
std::string gpu_cpu = argv[1];
std::string file = argv[2];
std::string datafile = argv[3];
auto prog = migraph::parse_onnx(file);
std::cout << prog << std::endl;
auto imageset = read_cifar10_images(datafile);
if(gpu_cpu == "gpu")
{
// GPU target
prog.compile(migraph::gpu::target{});
migraph::program::parameter_map m;
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
for(auto&& x : prog.get_parameter_shapes())
{
m[x.first] = migraph::gpu::to_gpu(migraph::generate_argument(x.second));
}
auto labels = imageset.first;
auto input = imageset.second;
auto ptr = input.data();
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[3072 * i]});
auto result = migraph::gpu::from_gpu(prog.eval(m));
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits);
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl << std::endl;
}
}
else
{
// CPU target
prog.compile(migraph::cpu::cpu_target{});
auto s = migraph::shape{migraph::shape::float_type, {1, 3, 32, 32}};
auto labels = imageset.first;
auto input = imageset.second;
auto ptr = input.data();
for(int i = 0; i < 10; i++)
{
std::cout << "label: " << static_cast<uint32_t>(labels[i]) << " ----> ";
auto input3 = migraph::argument{s, &ptr[3072 * i]};
auto result = prog.eval({{"0", input3}});
std::vector<float> logits;
result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
std::vector<float> probs = softmax<float>(logits);
for(auto x : probs)
std::cout << x << " ";
std::cout << std::endl;
}
}
}
@@ -6,9 +6,12 @@
#include <migraph/onnx.hpp>
-#include <migraph/cpu/cpu_target.hpp>
+#include <migraph/gpu/target.hpp>
+#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
+#include "softmax.hpp"
auto reverse_int(unsigned int i)
{
    unsigned char c1, c2, c3, c4;
@@ -97,16 +100,6 @@ std::vector<int32_t> read_mnist_labels(const std::string& full_path, int& number
    }
}

-std::vector<float> softmax(std::vector<float> p)
-{
-    size_t n = p.size();
-    std::vector<float> result(n);
-    std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
-    float s = std::accumulate(result.begin(), result.end(), 0.0f, std::plus<float>());
-    std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
-    return result;
-}
int main(int argc, char const* argv[])
{
    if(argc > 3)
@@ -121,15 +114,19 @@ int main(int argc, char const* argv[])
    std::string file = argv[1];
    auto prog = migraph::parse_onnx(file);
-    prog.compile(migraph::cpu::cpu_target{});
+    std::cout << prog << std::endl << std::endl;
+    prog.compile(migraph::gpu::target{});
    auto s = migraph::shape{migraph::shape::float_type, {1, 1, 28, 28}};
    std::cout << s << std::endl;
    auto ptr = input.data();
+    migraph::program::parameter_map m;
+    m["output"] =
+        migraph::gpu::to_gpu(migraph::generate_argument(prog.get_parameter_shape("output")));
    for(int i = 0; i < 20; i++)
    {
        std::cout << "label: " << labels[i] << " ----> ";
-        auto input3 = migraph::argument{s, &ptr[784 * i]};
-        auto result = prog.eval({{"Input3", input3}});
+        m["0"] = migraph::gpu::to_gpu(migraph::argument{s, &ptr[784 * i]});
+        auto result = migraph::gpu::from_gpu(prog.eval(m));
        std::vector<float> logits;
        result.visit([&](auto output) { logits.assign(output.begin(), output.end()); });
        std::vector<float> probs = softmax(logits);
......
@@ -234,7 +234,7 @@ struct onnx_parser
        }
        if(contains(attributes, "momentum"))
        {
-            epsilon = parse_value(attributes.at("momentum")).at<float>();
+            momentum = parse_value(attributes.at("momentum")).at<float>();
        }
        if(contains(attributes, "is_test"))
        {
......
#include <vector>
#include <algorithm>
#include <cmath>
#include <numeric>
template <typename T>
std::vector<T> softmax(const std::vector<T>& p)
{
size_t n = p.size();
std::vector<T> result(n);
std::transform(p.begin(), p.end(), result.begin(), [](auto x) { return std::exp(x); });
    T s = std::accumulate(result.begin(), result.end(), T{0}, std::plus<T>());
std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
return result;
}
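One caveat with this helper: std::exp overflows float for logits above roughly 88, producing inf and then NaN after the division. A numerically stable variant (a sketch, not part of this commit) subtracts the maximum logit first, which leaves the result unchanged mathematically:

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

template <typename T>
std::vector<T> stable_softmax(const std::vector<T>& p)
{
    // Shifting by the max logit keeps every exponent <= 0, so exp cannot overflow.
    T m = *std::max_element(p.begin(), p.end());
    std::vector<T> result(p.size());
    std::transform(p.begin(), p.end(), result.begin(), [=](auto x) { return std::exp(x - m); });
    T s = std::accumulate(result.begin(), result.end(), T{0});
    std::transform(result.begin(), result.end(), result.begin(), [=](auto x) { return x / s; });
    return result;
}
```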
@@ -5,7 +5,7 @@
#include <migraph/gpu/target.hpp>
#include <migraph/gpu/hip.hpp>
#include <migraph/generate.hpp>
-#include <migraph/verify.hpp>
+#include <migraph/verify_args.hpp>

migraph::argument run_cpu(const std::string& file)
{
@@ -46,20 +46,6 @@ int main(int argc, char const* argv[])
    auto x = run_cpu(file);
    auto y = run_gpu(file);

-    visit_all(x, y)([](auto cpu, auto gpu) {
-        if(migraph::verify_range(cpu, gpu, 100))
-        {
-            std::cout << "Passed" << std::endl;
-        }
-        else
-        {
-            std::cout << "Not equal" << std::endl;
-            std::cout << "cpu:" << std::endl;
-            std::cout << cpu << std::endl;
-            std::cout << "gpu:" << std::endl;
-            std::cout << gpu << std::endl;
-        }
-    });
+    migraph::verify_args(file, x, y, 100);
}
}
@@ -11,6 +11,7 @@ if(NOT TARGET MIOpen)
endif()

add_library(migraph_device
+   device/add.cpp
    device/add_relu.cpp
    device/contiguous.cpp
)
......
#include <migraph/gpu/device/add.hpp>
#include <migraph/gpu/device/nary.hpp>
namespace migraph {
namespace gpu {
namespace device {
void add(const argument& result, const argument& arg1, const argument& arg2)
{
nary(result, arg1, arg2)([](auto x, auto y) { return x + y; });
}
} // namespace device
} // namespace gpu
} // namespace migraph
@@ -5,10 +5,9 @@ namespace migraph {
namespace gpu {
namespace device {

-void add_relu(argument result, argument arg1, argument arg2)
+void add_relu(const argument& result, const argument& arg1, const argument& arg2)
{
-    nary_standard(std::move(result), std::move(arg1), std::move(arg2))(
-        [](auto x, auto y) { return max(0, x + y); });
+    nary(result, arg1, arg2)([](auto x, auto y) { return std::max<decltype(x + y)>(0, x + y); });
}

} // namespace device
......
@@ -33,10 +33,10 @@ inline auto launch(std::size_t global, std::size_t local)
    };
}

-inline auto gs_launch(std::size_t n, std::size_t local = 512)
+inline auto gs_launch(std::size_t n, std::size_t local = 1024)
{
    std::size_t groups = 1 + n / local;
-    std::size_t nglobal = std::min<std::size_t>(512, groups) * local;
+    std::size_t nglobal = std::min<std::size_t>(256, groups) * local;

    return [=](auto f) {
        launch(nglobal, local)([=](auto idx) {
@@ -48,6 +48,14 @@ inline auto gs_launch(std::size_t n, std::size_t local = 512)
    };
}

+// Workaround hcc's broken tile_static macro
+#ifdef tile_static
+#undef tile_static
+#define MIGRAPH_DEVICE_SHARED __attribute__((tile_static))
+#else
+#define MIGRAPH_DEVICE_SHARED __shared__
+#endif
+
} // namespace device
} // namespace gpu
} // namespace migraph
......
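The gs_launch change halves the group cap (512 to 256) while doubling the default workgroup size (512 to 1024), so the maximum global size stays at 262144 work-items; anything beyond that is covered by the grid-stride loops in the kernels below. A host-side sketch of the arithmetic (illustrative numbers only):

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>

int main()
{
    std::size_t n = 1000000, local = 1024;
    std::size_t groups  = 1 + n / local;                              // 977
    std::size_t nglobal = std::min<std::size_t>(256, groups) * local; // 262144
    // Each work-item then strides by nglobal, handling ceil(n / nglobal) = 4
    // elements: for(size_t i = idx.global; i < n; i += nglobal) ...
    std::cout << groups << " " << nglobal << std::endl;
}
```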
@@ -10,16 +10,25 @@ namespace migraph {
namespace gpu {
namespace device {

-template <class... Arguments>
-auto nary(argument result, Arguments... args)
-{
-    return [=](auto f) {
-        if(all_of({args...}, [](const shape& s) { return s.standard(); }))
-            nary_standard(result, args...)(f);
-        else
-            nary_nonstandard(result, args...)(f);
-    };
-}
+template <class T>
+using vec4 = T __attribute__((ext_vector_type(4)));
+
+template <class T>
+__device__ __host__ vec4<T>* as_vec4(T* x)
+{
+    return reinterpret_cast<vec4<T>*>(x);
+}
+
+template <class T>
+__device__ __host__ T* as_pointer(vec4<T>* x)
+{
+    return reinterpret_cast<T*>(x);
+}
+
+template <class... Ts>
+auto pack_vec4(Ts... xs)
+{
+    return [=](auto f, std::size_t n) { return f(as_vec4(xs)[n]...); };
+}
template <class F, class... Arguments>
@@ -28,14 +37,12 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
    const auto& output_shape = result.get_shape();
    visit_all(result, args...)([&](auto output, auto... inputs) {
        visit_tensor_size(output_shape.lens().size(), [&](auto ndim) {
-            auto data = make_sequence(
-                std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape().lens(),
-                                                           inputs.get_shape().strides()},
-                               inputs.data())...);
-            hip_tensor_descriptor<ndim> out_desc(output_shape.lens(), output_shape.strides());
+            auto data = pack(
+                std::make_pair(hip_tensor_descriptor<ndim>{inputs.get_shape()}, inputs.data())...);
+            hip_tensor_descriptor<ndim> out_desc(output_shape);
            auto* outp = output.data();
            gs_launch(output_shape.elements())([=](auto i) {
-                data([&](auto... ps) {
+                data([&](auto&&... ps) {
                    auto outidx = out_desc.multi(i);
                    outp[i] = f(ps.second[ps.first.linear(outidx)]...);
                });
@@ -44,24 +51,199 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
        });
}
template <class F>
void binary_broadcast_vec_impl(F f,
const argument& result,
const argument& arg1,
const argument& arg2)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = as_vec4(input1.data());
auto* yp = as_vec4(input2.data());
auto* outp = as_vec4(output.data());
const std::size_t vec_size = 4;
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size() / vec_size;
const std::size_t bdim_vec_len = bdim_len / vec_size;
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED vec4<type> buffer[2048 / vec_size];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
{
buffer[i] = yp[i];
}
__syncthreads();
auto* bp = as_pointer(buffer);
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b = bp[bidx];
vec4<type> x = xp[i];
vec4<type> out = outp[i];
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(x[j], b);
}
outp[i] = out;
}
});
});
}
template <class F>
void binary_broadcast_impl(F f, const argument& result, const argument& arg1, const argument& arg2)
{
const auto& output_shape = result.get_shape();
const auto& b_shape = arg2.get_shape();
auto bdim =
std::distance(b_shape.strides().begin(),
std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
return x != 0;
}));
auto bdim_len = output_shape.lens()[bdim];
auto bdim_stride = output_shape.strides()[bdim];
auto bdim_next_stride = bdim_stride * bdim_len;
visit_all(result, arg1, arg2)([&](auto output, auto input1, auto input2) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
auto* xp = input1.data();
auto* yp = input2.data();
auto* outp = output.data();
const std::size_t nlocal = 1024;
const std::size_t nglobal = 256 * nlocal;
const std::size_t n = output.size();
launch(nglobal, nlocal)([=](auto idx) __device__ {
MIGRAPH_DEVICE_SHARED type buffer[2048];
// Load bias into LDS
for(size_t i = idx.local; i < bdim_len; i += nlocal)
{
buffer[i] = yp[i];
}
__syncthreads();
// Process the data
for(size_t i = idx.global; i < n; i += nglobal)
{
auto bidx = (i % bdim_next_stride) / bdim_stride;
auto b = buffer[bidx];
type x = xp[i];
outp[i] = f(x, b);
}
});
});
}
template <class F, class... Arguments>
void nary_standard_vec_impl(F f, argument result, Arguments... args)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
using type = std::remove_cv_t<typename decltype(output)::value_type>;
const std::size_t vec_size = 4;
auto data = pack_vec4(inputs.data()...);
auto* outp = as_vec4(output.data());
gs_launch(output_shape.elements() / vec_size)([=](auto i) {
vec4<type> out = outp[i];
data(
[&](auto... xs) {
for(std::size_t j = 0; j < vec_size; j++)
{
out[j] = f(xs[j]...);
}
},
i);
outp[i] = out;
});
});
}
template <class F, class... Arguments>
void nary_standard_impl(F f, argument result, Arguments... args)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const auto& output_shape = result.get_shape();
visit_all(result, args...)([&](auto output, auto... inputs) {
auto data = pack(inputs.data()...);
auto* outp = output.data();
gs_launch(output_shape.elements())(
[=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
});
}
template <class F, class... Arguments>
void nary_impl(F f, argument result, Arguments... args)
{
bool standard = all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
bool packed = all_of({args.get_shape()...}, [](const shape& s) { return s.packed(); });
bool same_shapes =
all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
if(standard or (packed and same_shapes))
nary_standard_impl(f, result, args...);
else
nary_nonstandard_impl(f, result, args...);
}
template <class... Arguments>
auto nary_nonstandard(argument result, Arguments... args)
{
-    return [=](auto f) { return nary_nonstandard_impl(f, result, args...); };
+    return [=](auto f) { nary_nonstandard_impl(f, result, args...); };
}

template <class... Arguments>
auto nary_standard(argument result, Arguments... args)
{
-    return [=](auto f) {
-        // assert(x.get_shape().elements() == y.get_shape().elements());
-        const auto& output_shape = result.get_shape();
-        visit_all(result, args...)([&](auto output, auto... inputs) {
-            auto data = make_sequence(inputs.data()...);
-            auto* outp = output.data();
-            gs_launch(output_shape.elements())(
-                [=](auto i) { data([&](auto... xps) { outp[i] = f(xps[i]...); }); });
-        });
-    };
+    return [=](auto f) { nary_standard_impl(f, result, args...); };
}

+template <class... Arguments>
+auto nary(argument result, Arguments... args)
+{
+    return [=](auto f) { nary_impl(f, result, args...); };
+}
+
+inline auto nary(const argument& result, const argument& arg1, const argument& arg2)
+{
+    return [=](auto f) {
+        // TODO: Check result and arg1 shape is the same
+        if(arg1.get_shape().standard() and arg2.get_shape().broadcasted())
+        {
+            auto not_zero = [](auto x) { return x != 0; };
+            const auto& strides = arg2.get_shape().strides();
+            auto b_it = std::find_if(strides.begin(), strides.end(), not_zero);
+            auto b_idx = std::distance(strides.begin(), b_it);
+            auto b_len = result.get_shape().lens()[b_idx];
+            auto b_stride = result.get_shape().strides()[b_idx];
+            assert(arg2.get_shape().lens()[b_idx] == b_len);
+            if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+            {
+                const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                                            (arg1.get_shape().elements() % 4 == 0);
+                if(divisible_by_4)
+                    binary_broadcast_vec_impl(f, result, arg1, arg2);
+                else
+                    binary_broadcast_impl(f, result, arg1, arg2);
+                return;
+            }
+        }
+        nary_impl(f, result, arg1, arg2);
+    };
}
......
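The broadcast kernels stage the bias vector in LDS (hence the b_len <= 2048 guard, matching the buffer size) and recover which bias element a flat output index needs as (i % bdim_next_stride) / bdim_stride. A host-side sketch of that index math, with illustrative NCHW shapes not taken from the commit:

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Output lens {N, C, H, W} = {2, 64, 7, 7} with the bias broadcast along C:
    const std::size_t bdim_stride      = 7 * 7;      // stride of C in the output
    const std::size_t bdim_next_stride = 64 * 7 * 7; // stride of N
    std::size_t i = 1 * bdim_next_stride + 5 * bdim_stride + 3; // image 1, channel 5
    std::cout << (i % bdim_next_stride) / bdim_stride << std::endl; // prints 5
}
```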
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_DEAVICE_TENSOR_HPP

#include <hip/hip_runtime.h>
#include <migraph/functional.hpp>
namespace migraph {
namespace gpu {
@@ -53,14 +54,13 @@ template <size_t NDim>
struct hip_tensor_descriptor
{
    __device__ __host__ hip_tensor_descriptor() = default;
-    template <typename T, typename V>
-    __device__ __host__ hip_tensor_descriptor(const T& lens_ext, const V& strides_ext)
-    {
-        for(size_t i = 0; i < NDim; i++)
-            lens[i] = lens_ext[i];
-        for(size_t i = 0; i < NDim; i++)
-            strides[i] = strides_ext[i];
-    }
+    hip_tensor_descriptor(const shape& s)
+    {
+        std::copy(s.lens().begin(), s.lens().end(), lens);
+        std::copy(s.strides().begin(), s.strides().end(), strides);
+    }

    __device__ __host__ hip_index<NDim> multi(size_t idx) const
    {
        hip_index<NDim> result{};
......
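The multi() and linear() bodies are elided in this hunk; for reference, they convert between flat and multi-dimensional indices in the usual way, decomposing by lens and recombining by strides. A hypothetical 2-D host-side stand-in (the member names mirror the struct above, but the method bodies are assumptions, not the commit's code):

```cpp
#include <cstddef>
#include <iostream>

struct descriptor2d
{
    std::size_t lens[2];
    std::size_t strides[2];
    // Decompose a flat row-major index using lens.
    void multi(std::size_t idx, std::size_t out[2]) const
    {
        out[0] = idx / lens[1];
        out[1] = idx % lens[1];
    }
    // Recombine a multi-index using (possibly non-contiguous) strides.
    std::size_t linear(const std::size_t idx[2]) const
    {
        return idx[0] * strides[0] + idx[1] * strides[1];
    }
};

int main()
{
    descriptor2d d{{3, 4}, {1, 3}}; // a 3x4 view laid out column-major
    std::size_t m[2];
    d.multi(7, m);                         // flat 7 -> (1, 3)
    std::cout << d.linear(m) << std::endl; // prints 1*1 + 3*3 = 10
}
```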