Commit f9a06df3 authored by Cagri Eryilmaz's avatar Cagri Eryilmaz
Browse files

Merge branch 'develop' into unet

parents 07189c21 0b04fc80
#include <migraphx/config.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/context.hpp>
#include <migraphx/cpu/context.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
// Operation holding a host buffer that is allocated once, during finalize,
// and returned unchanged on every compute call. The buffer is reported with
// lifetime::global so the memory planner treats it as persistent.
struct cpu_preallocate : auto_register_op<cpu_preallocate>
{
    shape s;
    std::string id = "";
    argument data;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        // Only the shape and the id are serialized; the buffer itself is
        // recreated by finalize.
        return pack(f(self.s, "shape"), f(self.id, "id"));
    }

    std::string name() const { return "cpu::preallocate"; }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Takes no inputs; the output shape is fixed at construction time.
        check_shapes{inputs, *this}.has(0);
        return s;
    }

    argument compute(context&, const shape&, const std::vector<argument>&) const
    {
        // Hand out the buffer created in finalize.
        return data;
    }

    void finalize(context&, const shape&, const std::vector<shape>&)
    {
        // Allocate the backing storage for shape s.
        data = argument(s);
    }

    lifetime get_lifetime() const { return lifetime::global; }
};
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -22,6 +22,7 @@
#include <migraphx/memory_coloring.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/cpu/fuse_ops.hpp>
#include <migraphx/cpu/write_literals.hpp>
#include <migraphx/cpu/allocation_model.hpp>
......@@ -76,6 +77,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
write_literals{},
dead_code_elimination{},
memory_coloring{"cpu::allocate"},
dead_code_elimination{},
preallocate_param{"scratch", cpu_allocation_model{}},
dead_code_elimination{}};
}
......
......@@ -61,6 +61,7 @@ add_library(migraphx_device
device/pad.cpp
device/pow.cpp
device/prelu.cpp
device/prefix_scan_sum.cpp
device/recip.cpp
device/reduce_max.cpp
device/reduce_mean.cpp
......@@ -121,6 +122,7 @@ add_library(migraphx_gpu
convert.cpp
convolution.cpp
deconvolution.cpp
device_name.cpp
eliminate_workspace.cpp
elu.cpp
fuse_ops.cpp
......@@ -139,7 +141,6 @@ add_library(migraphx_gpu
pack_int8_args.cpp
pad.cpp
pooling.cpp
preallocate_param.cpp
quant_convolution.cpp
reverse.cpp
rnn_variable_seq_lens.cpp
......@@ -192,6 +193,7 @@ register_migraphx_gpu_ops(hip_
pad
pow
prelu
prefix_scan_sum
recip
reduce_max
reduce_mean
......
......@@ -11,6 +11,11 @@ operation gpu_allocation_model::allocate(const shape& s) const
return make_op(name(), {{"shape", to_value(s)}});
}
// Build the op that preallocates a device buffer of shape `s` at load time;
// `id` names the buffer so it can be looked up later (e.g. the "scratch"
// parameter installed by the preallocate_param pass).
operation gpu_allocation_model::preallocate(const shape& s, const std::string& id) const
{
return make_op("hip::hip_allocate_memory", {{"shape", to_value(s)}, {"id", id}});
}
std::string gpu_allocation_model::copy() const { return "hip::copy"; }
} // namespace gpu
......
......@@ -2,9 +2,9 @@
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/context.hpp>
#include <migraphx_kernels.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
......@@ -12,36 +12,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(rank<1>{}, props);
}
template <class T>
std::string generate_index_ints(const std::vector<T>& v)
{
......
......@@ -5,6 +5,7 @@
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -15,79 +16,6 @@ namespace device {
#define MIGRAPHX_NO_DPP
#endif
struct sum
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x + y;
}
};
struct product
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return x * y;
}
};
struct id
{
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x;
}
};
struct mean
{
size_t item_num = 1;
template <class T>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x) const
{
return x / static_cast<T>(item_num);
}
};
struct max
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return (x > y) ? x : y;
}
};
struct min
{
template <class T, class U>
MIGRAPHX_DEVICE_CONSTEXPR auto operator()(T x, U y) const
{
return (x < y) ? x : y;
}
};
struct lowest
{
template <class T>
__device__ __host__ operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::lowest());
}
};
struct highest
{
template <class T>
__device__ __host__ operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::max());
}
};
#ifdef MIGRAPHX_NO_DPP
template <index_int N,
class Op,
......
#ifndef MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP
#define MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Binary reduction operator: addition.
struct sum
{
    template <class A, class B>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(A a, B b) const
    {
        return a + b;
    }
};
// Binary reduction operator: multiplication.
struct product
{
    template <class A, class B>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(A a, B b) const
    {
        return a * b;
    }
};
// Unary identity transform: returns its argument unchanged.
struct id
{
    template <class V>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(V v) const
    {
        return v;
    }
};
// Unary finalizer that divides an accumulated value by the element count.
struct mean
{
    size_t item_num = 1;

    template <class V>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(V v) const
    {
        // The divisor is cast to the accumulator's type, so integral
        // accumulators use integer division.
        return v / static_cast<V>(item_num);
    }
};
// Binary reduction operator: larger of the two values. The ternary keeps the
// result in the operands' common type when A and B differ.
struct max
{
    template <class A, class B>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(A a, B b) const
    {
        return (a > b) ? a : b;
    }
};
// Binary reduction operator: smaller of the two values. The ternary keeps the
// result in the operands' common type when A and B differ.
struct min
{
    template <class A, class B>
    MIGRAPHX_DEVICE_CONSTEXPR auto operator()(A a, B b) const
    {
        return (a < b) ? a : b;
    }
};
// Tag that implicitly converts to the lowest finite value of the target type.
// The limit is taken on the host representation and mapped back through
// device_cast (presumably to handle types like half — confirm in types.hpp).
struct lowest
{
template <class T>
__device__ __host__ operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::lowest());
}
};
// Tag that implicitly converts to the highest finite value of the target type
// (numeric_limits::max on the host representation, mapped via device_cast).
struct highest
{
template <class T>
__device__ __host__ operator T() const
{
return device_cast(std::numeric_limits<host_type<T>>::max());
}
};
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_REDUCE_OPS_HPP
#ifndef MIGRAPHX_GUARD_DEVICE_SCAN_HPP
#define MIGRAPHX_GUARD_DEVICE_SCAN_HPP
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
#include <migraphx/gpu/device/multi_index.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Block-wide inclusive scan over the elements visited by the for-stride
// functor `fs`. Each pass scans one block-sized window in shared memory; the
// running total of the previous window is carried in `x` and folded into the
// first element of the next window by thread 0.
//
// FIX: the original in-place update
//     buffer[idx.local + s] = op(buffer[idx.local], buffer[idx.local + s]);
// raced — thread t reads buffer[t] while thread t-s writes it in the same
// step, with no barrier between the read and the write. The loop below does
// the standard Hillis–Steele scan through a register with a barrier on each
// side of the write-back. A trailing barrier also protects the shared buffer
// (and the carry read) from being overwritten by the next fs iteration.
template <index_int N,
          class Op,
          class T,
          class ForStride,
          class Input,
          class Output,
          MIGRAPHX_REQUIRES(not std::is_integral<ForStride>{})>
__device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input, Output output)
{
    using type = decltype(input(deduce_for_stride(fs)));
    MIGRAPHX_DEVICE_SHARED type buffer[N];
    type x = init;
    fs([&](auto i) {
        // Load one element per thread; thread 0 folds in the carry from the
        // previous window.
        if(idx.local == 0)
            buffer[idx.local] = op(input(i), x);
        else
            buffer[idx.local] = input(i);
        __syncthreads();
        for(index_int s = 1; s < idx.nlocal(); s *= 2)
        {
            type partial = buffer[idx.local];
            if(idx.local >= s)
                partial = op(buffer[idx.local - s], partial);
            // Everyone has read their neighbors before anyone writes back.
            __syncthreads();
            buffer[idx.local] = partial;
            __syncthreads();
        }
        // Carry the window total into the next iteration; the scan is
        // inclusive, so the last slot already contains the carried prefix.
        x = buffer[idx.nlocal() - 1];
        output(i, buffer[idx.local]);
        // Keep the next iteration from clobbering buffer while other threads
        // are still reading the carry/result.
        __syncthreads();
    });
}
// Convenience overload: scan a flat range of n elements by wrapping the count
// in a local-stride for-stride functor and deferring to the general overload.
template <index_int N, class Op, class T, class Input, class Output>
__device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
{
    auto stride = [&](auto g) -> decltype(g(index_int{})) { return idx.local_stride(n, g); };
    block_scan<N>(idx, op, init, stride, input, output);
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_SCAN_HPP
#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/gpu/device/scan.hpp>
#include <migraphx/gpu/device/reduce_ops.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Inclusive prefix sum of `arg` along `axis`, written into `result`.
// One 256-thread block is launched per "row": the reduced shape (axis length
// collapsed to 1) enumerates the rows, and each block scans the n elements of
// its row via block_scan. Lambdas capture by value since they run on device.
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis)
{
const index_int block_size = 256;
const index_int n = arg.get_shape().lens()[axis];
auto rlens = result.get_shape().lens();
rlens[axis] = 1;
hip_visit_all(result, arg, result.get_shape().with_lens(rlens))(
[=](auto output, auto input, auto rshape) {
// One block per element of rshape (i.e. per scan row).
gs_launch(stream, rshape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
const auto ridx = rshape.multi(i / block_size);
// Map a position j along the scan axis to a full multi-index.
auto compute_idx = [&](auto j) {
auto k = ridx;
k[axis] = j;
return k;
};
block_scan<block_size>(idx,
sum{},
0,
n,
[&](auto j) { return input[compute_idx(j)]; },
[&](auto j, auto x) { output[compute_idx(j)] = x; });
});
});
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/errors.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/stringutils.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Fallback for HIP runtimes whose device properties lack gcnArchName: build
// the name from the numeric gcnArch field (e.g. 906 -> "gfx906").
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
// Preferred overload, chosen by rank<1> when the expression
// std::string(props.gcnArchName) is well-formed (SFINAE on the trailing
// return type); newer runtimes report the full arch string here.
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
// Return the ordinal of the HIP device current on this thread; throws when
// the runtime reports no device.
int get_device_id()
{
    int device = 0;
    if(hipGetDevice(&device) != hipSuccess)
        MIGRAPHX_THROW("No device");
    return device;
}
// Architecture name of the current device, taken from its device properties
// (gcnArchName when available, otherwise derived from gcnArch).
std::string get_device_name()
{
    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, get_device_id()) != hipSuccess)
        MIGRAPHX_THROW("Failed to get device properties");
    return get_arch_name(rank<1>{}, props);
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -71,7 +71,7 @@ void gemm_impl(context& ctx,
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
rocblas_gemm_flags flag = rocblas_gemm_flags_none;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens();
......
......@@ -14,6 +14,7 @@ struct gpu_allocation_model
std::string name() const;
std::string copy() const;
operation allocate(const shape& s) const;
operation preallocate(const shape& s, const std::string& id) const;
};
} // namespace gpu
......
#ifndef MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP
#define MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_PREFIX_SCAN_SUM_HPP
#ifndef MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
#define MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
std::string get_device_name();
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_DEVICE_NAME_HPP
#ifndef MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
#define MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
#include <migraphx/gpu/name.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/prefix_scan_sum.hpp>
#include <migraphx/op/prefix_scan_sum.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
// GPU wrapper for op::prefix_scan_sum. The last input is the preallocated
// output buffer, which the op writes into and aliases as its result.
struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
{
    op::prefix_scan_sum op;

    template <class Self, class T>
    static auto reflect(Self& self, T f)
    {
        return migraphx::reflect(self.op, f);
    }

    shape compute_shape(const std::vector<shape>& inputs) const
    {
        // Drop the trailing allocation argument before validating/normalizing.
        auto in_shapes = inputs;
        in_shapes.pop_back();
        check_shapes{in_shapes, *this}.standard();
        return op.normalize_compute_shape(in_shapes);
    }

    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
    {
        if(op.exclusive or op.reverse)
            MIGRAPHX_THROW("Exclusive and reverse scan not supported");
        // args[0] is the input, args[1] the preallocated output buffer.
        device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
        return args[1];
    }

    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        // The result aliases the final (allocation) argument.
        return shapes.size() - 1;
    }
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_PREFIX_SCAN_SUM_HPP
......@@ -163,6 +163,7 @@ struct miopen_apply
add_extend_op("lrn");
add_extend_op("pad");
add_extend_op("pooling");
add_extend_op("prefix_scan_sum");
add_extend_op("reduce_max");
add_extend_op("reduce_mean");
add_extend_op("reduce_min");
......
......@@ -14,6 +14,7 @@
#include <migraphx/insert_pad.hpp>
#include <migraphx/memory_coloring.hpp>
#include <migraphx/normalize_ops.hpp>
#include <migraphx/preallocate_param.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/remap.hpp>
......@@ -31,7 +32,6 @@
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/mlir_conv.hpp>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/preallocate_param.hpp>
#include <migraphx/gpu/schedule_model.hpp>
#include <migraphx/gpu/sync_device.hpp>
#include <migraphx/gpu/target.hpp>
......@@ -98,7 +98,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
schedule{gpu::schedule_model{ctx.get_current_device().nstreams()}, not enabled(MIGRAPHX_DISABLE_SCHEDULE_PASS{})},
memory_coloring{"hip::allocate"},
sync_device{},
preallocate_param{"scratch", &ctx},
preallocate_param{"scratch", gpu_allocation_model{}},
dead_code_elimination{},
eliminate_workspace{},
eliminate_allocation{"hip::allocate"},
......
......@@ -17,6 +17,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/config.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/common.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tf/tf_parser.hpp>
......@@ -74,66 +75,11 @@ instruction_ref tf_parser::node_info::make_contiguous(instruction_ref ins) const
return mm->add_instruction(make_op("contiguous"), ins);
}
// Numpy-style broadcast of two dimension vectors.
//
// Example: s0 = (3,2,4,5), s1 = (2,1,1)  -> (3,2,4,5)
// Example: s0 = (3,2,1,5), s1 = (2,7,5)  -> (3,2,7,5)
//
// Dimensions are aligned from the trailing end; a dimension of 1 broadcasts
// against anything, and any other mismatch is an error.
std::vector<std::size_t> compute_broadcasted_lens(std::vector<std::size_t> s0,
                                                  std::vector<std::size_t> s1)
{
    // Make s0 the lower-rank shape; the extra leading dims of s1 pass through.
    if(s0.size() > s1.size())
    {
        s0.swap(s1);
    }
    std::vector<std::size_t> out_lens(s1);
    auto offset = s1.size() - s0.size();
    for(std::size_t i = 0; i < s0.size(); i++)
    {
        auto a = s0[i];
        auto b = s1[i + offset];
        if(a != b and a != 1 and b != 1)
        {
            MIGRAPHX_THROW("COMPUTE_BROADCASTLEN: shape {" + to_string_range(s0) + "} and {" +
                           to_string_range(s1) + "} mismatch!");
        }
        out_lens[i + offset] = std::max(a, b);
    }
    return out_lens;
}
instruction_ref tf_parser::node_info::add_broadcastable_binary_op(const std::string& op_name,
instruction_ref arg0,
instruction_ref arg1) const
{
if(arg0->get_shape().lens() != arg1->get_shape().lens())
{
// Get lengths for both arguments
auto s0 = arg0->get_shape().lens();
auto s1 = arg1->get_shape().lens();
auto out_lens = compute_broadcasted_lens(s0, s1);
auto l0 = arg0;
if(arg0->get_shape().lens() != out_lens)
l0 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg0);
auto l1 = arg1;
if(arg1->get_shape().lens() != out_lens)
l1 = add_instruction(make_op("multibroadcast", {{"output_lens", out_lens}}), arg1);
return add_instruction(make_op(op_name), l0, l1);
}
else
{
return add_instruction(make_op(op_name), {arg0, arg1});
}
return add_common_op(*mm, make_op(op_name), {arg0, arg1});
}
int64_t tf_parser::parse_axis(const int64_t dim, const size_t num_dims) const
......
......@@ -6,6 +6,7 @@
#include <migraphx/gpu/kernel.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
......@@ -74,19 +75,10 @@ migraphx::src_file make_src_file(const std::string& name, const std::string& con
return {name, std::make_pair(content.data(), content.data() + content.size())};
}
std::string get_device_name()
{
hipDeviceProp_t props{};
int device;
EXPECT(hipGetDevice(&device) == hipSuccess);
EXPECT(hipGetDeviceProperties(&props, device) == hipSuccess);
return "gfx" + std::to_string(props.gcnArch);
}
TEST_CASE(simple_compile_hip)
{
auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", write_2s)}, "", get_device_name());
{make_src_file("main.cpp", write_2s)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1);
migraphx::argument input{{migraphx::shape::int8_type, {5}}};
......@@ -103,7 +95,7 @@ TEST_CASE(simple_compile_hip)
TEST_CASE(code_object_hip)
{
auto binaries = migraphx::gpu::compile_hip_src(
{make_src_file("main.cpp", add_2s_binary)}, "", get_device_name());
{make_src_file("main.cpp", add_2s_binary)}, "", migraphx::gpu::get_device_name());
EXPECT(binaries.size() == 1);
migraphx::shape input{migraphx::shape::int8_type, {5}};
......
......@@ -9,6 +9,10 @@
#include <unordered_map>
#include <vector>
#ifdef __linux__
#include <unistd.h>
#endif
#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP
#define MIGRAPHX_GUARD_TEST_TEST_HPP
......@@ -79,8 +83,8 @@ struct function
}
};
template <class Iterator>
inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last)
template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
{
if(start != last)
{
......@@ -90,22 +94,15 @@ inline std::ostream& stream_range(std::ostream& s, Iterator start, Iterator last
return s;
}
inline std::ostream& operator<<(std::ostream& s, std::nullptr_t)
// Print the literal token "nullptr" for null-pointer constants, so failed
// assertions involving nullptr render readably on any stream-like type.
template <class Stream>
inline Stream& operator<<(Stream& os, std::nullptr_t)
{
    os << "nullptr";
    return os;
}
template <class T>
inline std::ostream& operator<<(std::ostream& s, const std::vector<T>& v)
{
s << "{ ";
stream_range(s, v.begin(), v.end());
s << "}";
return s;
}
inline std::ostream& operator<<(std::ostream& s, const std::vector<bool>& v)
template <class Stream, class Range>
inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end()))
{
s << "{ ";
stream_range(s, v.begin(), v.end());
......@@ -264,6 +261,32 @@ struct capture
}
};
// ANSI SGR (Select Graphic Rendition) codes used to colorize runner output.
enum class color
{
reset = 0,
bold = 1,
underlined = 4,
fg_red = 31,
fg_green = 32,
fg_yellow = 33,
fg_blue = 34,
fg_default = 39,
bg_red = 41,
bg_green = 42,
bg_yellow = 43,
bg_blue = 44,
bg_default = 49
};
// Emit the escape sequence for c, but only when stdout is a terminal; on
// Windows (or when output is redirected) the stream is left untouched.
inline std::ostream& operator<<(std::ostream& os, const color& c)
{
#ifndef _WIN32
// Checked once per process; note this tests stdout even if os is not cout.
static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m";
#endif
return os;
}
template <class T, class F>
void failed(T x, const char* msg, const char* func, const char* file, int line, F f)
{
......@@ -271,7 +294,7 @@ void failed(T x, const char* msg, const char* func, const char* file, int line,
{
std::cout << func << std::endl;
std::cout << file << ":" << line << ":" << std::endl;
std::cout << " FAILED: " << msg << " "
std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " "
<< "[ " << x << " ]" << std::endl;
f();
}
......@@ -315,7 +338,7 @@ auto near(T px, U py, double ptol = 1e-6f)
using string_map = std::unordered_map<std::string, std::vector<std::string>>;
template <class Keyword>
string_map parse(std::vector<std::string> as, Keyword keyword)
string_map generic_parse(std::vector<std::string> as, Keyword keyword)
{
string_map result;
......@@ -331,19 +354,22 @@ string_map parse(std::vector<std::string> as, Keyword keyword)
{
flag = f.front();
result[flag]; // Ensure the flag exists
flag = f.back();
}
}
return result;
}
using test_case = std::function<void()>;

// Global registry of test cases, kept in registration order. Function-local
// static avoids the static-initialization-order problem for auto-registered
// cases.
inline auto& get_test_cases()
{
    // NOLINTNEXTLINE
    static std::vector<std::pair<std::string, test_case>> registry;
    return registry;
}
inline void add_test_case(std::string name, std::function<void()> f)
inline void add_test_case(std::string name, test_case f)
{
get_test_cases().emplace_back(std::move(name), std::move(f));
}
......@@ -357,37 +383,243 @@ struct auto_register_test_case
}
};
inline void run_test_case(const std::string& name, const std::function<void()>& f)
struct failure_error
{
std::cout << "[ RUN ] " << name << std::endl;
f();
std::cout << "[ COMPLETE ] " << name << std::endl;
}
};
inline void run(int argc, const char* argv[])
// Abort the current test case; the exception is caught by the runner and the
// case is reported as failed.
[[noreturn]] inline void fail() { throw failure_error{}; }
struct driver
{
std::vector<std::string> as(argv + 1, argv + argc);
driver()
{
add_flag({"--help", "-h"}, "Show help");
add_flag({"--list", "-l"}, "List all test cases");
add_flag({"--continue", "-c"}, "Continue after failure");
add_flag({"--quiet", "-q"}, "Don't print out extra output");
}
struct argument
{
std::vector<std::string> flags = {};
std::string help = "";
int nargs = 1;
};
auto args = parse(as, [](auto &&) -> std::vector<std::string> { return {}; });
auto cases = args[""];
if(cases.empty())
void add_arg(const std::vector<std::string>& flags, const std::string& help = "")
{
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second);
arguments.push_back(argument{flags, help, 1});
}
else
void add_flag(const std::vector<std::string>& flags, const std::string& help = "")
{
std::unordered_map<std::string, std::function<void()>> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& name : cases)
arguments.push_back(argument{flags, help, 0});
}
void show_help(const std::string& exe) const
{
std::cout << std::endl;
std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl;
std::cout << " ";
std::cout << exe << " <test-case>... <options>" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl;
std::cout << " ";
std::cout << color::fg_green << "<test-case>..." << color::reset;
std::cout << std::endl;
std::cout << " "
<< "Test case name to run" << std::endl;
std::cout << std::endl;
std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl;
for(auto&& arg : arguments)
{
auto f = m.find(name);
if(f == m.end())
std::cout << "[ ERROR ] Test case '" << name << "' not found." << std::endl;
std::string prefix = " ";
std::cout << color::fg_green;
for(const std::string& a : arg.flags)
{
std::cout << prefix;
std::cout << a;
prefix = ", ";
}
std::cout << color::reset << std::endl;
std::cout << " " << arg.help << std::endl;
}
}
// Stream for runner output: std::cout normally, or a stream that silently
// discards everything when --quiet was requested.
std::ostream& out() const
{
    struct null_buffer : std::streambuf
    {
        // Report every character as successfully "written" so the null
        // stream never enters a failed state.
        // (fix: `virtual` dropped — a function should carry exactly one of
        // virtual/override/final; `override` already implies virtual.)
        int overflow(int c) override { return c; }
    };
    static null_buffer buffer;
    static std::ostream null_stream(&buffer);
    if(quiet)
        return null_stream;
    return std::cout;
}
string_map parse(int argc, const char* argv[]) const
{
std::vector<std::string> args(argv + 1, argv + argc);
string_map keys;
for(auto&& arg : arguments)
{
for(auto&& flag : arg.flags)
{
keys[flag] = {arg.flags.front()};
if(arg.nargs == 0)
keys[flag].push_back("");
}
}
auto result = generic_parse(args, [&](auto&& s) -> std::vector<std::string> {
if(keys.count(s) > 0)
return keys[s];
else
run_test_case(name, f->second);
return {};
});
result["__exe__"].push_back(argv[0]);
return result;
}
// Reassemble a shell command line from parsed arguments: the executable,
// then the positional test-case names, then each flag followed by its
// values. Every value is double-quoted so names containing spaces survive
// the round trip through std::system.
static std::string create_command(const string_map& args)
{
    std::stringstream ss;
    ss << args.at("__exe__").front();
    auto append_quoted = [&](const std::vector<std::string>& values) {
        for(const auto& v : values)
            ss << " \"" << v << "\"";
    };
    if(args.count("") > 0)
        append_quoted(args.at(""));
    for(const auto& entry : args)
    {
        // Skip the synthetic exe entry and the positional bucket.
        if(entry.first == "__exe__" or entry.first.empty())
            continue;
        ss << " " << entry.first;
        append_quoted(entry.second);
    }
    return ss.str();
}
// Re-run a single test case in a child process via std::system, so a crash
// cannot take down the parent runner. Returns an error message, or an empty
// string on success. NOTE(review): despite the name, this uses std::system,
// not POSIX fork().
static std::string fork(const std::string& name, string_map args)
{
std::string msg;
// The child runs exactly this one case, quietly, and without --continue
// so it does not recurse into forking again.
args[""] = {name};
args.erase("--continue");
args["--quiet"];
auto cmd = create_command(args);
auto r = std::system(cmd.c_str()); // NOLINT
if(r != 0)
msg = "Exited with " + std::to_string(r);
return msg;
}
// Run one test case and record the outcome. With --continue the case is run
// in a child process (see fork) so a failure doesn't stop the run; otherwise
// it runs in-process and a thrown failure_error marks it failed.
void run_test_case(const std::string& name, const test_case& f, const string_map& args)
{
ran++;
out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name
<< color::reset << std::endl;
// Empty msg means success; otherwise it holds the failure description.
std::string msg;
if(args.count("--continue") > 0)
{
msg = fork(name, args);
}
else
{
try
{
f();
}
catch(const failure_error&)
{
msg = "Test failure";
}
}
if(msg.empty())
{
out() << color::fg_green << "[ COMPLETE ] " << color::reset << color::bold << name
<< color::reset << std::endl;
}
else
{
// `failed` here is the driver's member vector of failed case names.
failed.push_back(name);
out() << color::fg_red << "[ FAILED ] " << color::reset << color::bold << name
<< color::reset << ": " << color::fg_yellow << msg << color::reset << std::endl;
}
}
// Top-level driver: parse argv, handle --help/--list/--quiet, then run
// either every registered case or just the ones named on the command line.
// Prints a summary and exits the process with status 1 if anything failed.
void run(int argc, const char* argv[])
{
auto args = parse(argc, argv);
if(args.count("--help") > 0)
{
show_help(args.at("__exe__").front());
return;
}
if(args.count("--list") > 0)
{
for(auto&& tc : get_test_cases())
out() << tc.first << std::endl;
return;
}
if(args.count("--quiet") > 0)
quiet = true;
// Positional arguments are the requested case names.
auto cases = args[""];
if(cases.empty())
{
// No names given: run everything in registration order.
for(auto&& tc : get_test_cases())
run_test_case(tc.first, tc.second, args);
}
else
{
// Index the registry by name for lookup of the requested cases.
std::unordered_map<std::string, test_case> m(get_test_cases().begin(),
get_test_cases().end());
for(auto&& iname : cases)
{
// get_case_names may expand one input name into several cases.
for(auto&& name : get_case_names(iname))
{
auto f = m.find(name);
if(f == m.end())
{
// Unknown names count as failures rather than being ignored.
out() << color::fg_red << "[ ERROR ] Test case '" << name
<< "' not found." << color::reset << std::endl;
failed.push_back(name);
}
else
run_test_case(name, f->second, args);
}
}
}
out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran"
<< color::reset << std::endl;
if(not failed.empty())
{
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size()
<< " tests failed" << color::reset << std::endl;
for(auto&& name : failed)
out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name
<< color::reset << std::endl;
// Nonzero exit so CI notices the failures.
std::exit(1);
}
}
std::function<std::vector<std::string>(const std::string&)> get_case_names =
[](const std::string& name) -> std::vector<std::string> { return {name}; };
std::vector<argument> arguments = {};
std::vector<std::string> failed = {};
std::size_t ran = 0;
bool quiet = false;
};
inline void run(int argc, const char* argv[])
{
driver d{};
d.run(argc, argv);
}
} // namespace test
......@@ -404,7 +636,7 @@ inline void run(int argc, const char* argv[])
__PRETTY_FUNCTION__, \
__FILE__, \
__LINE__, \
&std::abort)
&test::fail)
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment