"src/targets/vscode:/vscode.git/clone" did not exist on "493bb8d562a222292138b3d5db033b9da6a53960"
Commit 11e155c2 authored by Paul's avatar Paul
Browse files

Merge

parents 8a9c5bce aa7ff911
......@@ -154,6 +154,13 @@ struct hip_device
std::size_t get_cu_count() const { return device_props.multiProcessorCount; }
std::size_t get_max_workitems_per_cu() const
{
return device_props.maxThreadsPerMultiProcessor;
}
std::size_t get_max_workitems_per_block() const { return device_props.maxThreadsPerBlock; }
private:
std::size_t device_id = 0;
std::size_t current_stream = 0;
......@@ -235,6 +242,8 @@ struct context
this->current_device = std::make_shared<hip_device>(0, n_streams);
}
any_ptr get_queue() { return get_stream().get(); }
private:
// TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device;
......
......@@ -10,7 +10,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void prefix_scan_sum(hipStream_t stream, const argument& result, const argument& arg, int32_t axis);
void prefix_scan_sum(hipStream_t stream,
const argument& result,
const argument& arg,
int32_t axis,
bool exclusive,
bool reverse);
} // namespace device
} // namespace gpu
......
......@@ -14,7 +14,7 @@ namespace gpu {
struct eliminate_workspace
{
std::string name() const { return "eliminate_workspace"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -16,7 +16,7 @@ struct fuse_ops
context* ctx = nullptr;
bool fast_math = true;
std::string name() const { return "gpu::fuse_ops"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace gpu
......
......@@ -18,6 +18,8 @@ namespace gpu {
struct context;
void blas_shape(const shape& s);
template <class Op>
struct rocblas_gemm
{
......@@ -25,6 +27,7 @@ struct rocblas_gemm
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
bool compute_fp32 = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -49,13 +52,14 @@ struct rocblas_gemm
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
blas_shape(inputs[0]);
blas_shape(inputs[1]);
// if gemm and add are fused
if(not float_equal(beta, 0))
if(in_shapes.size() > 2)
{
auto cmat_shape = in_shapes.back();
in_shapes.pop_back();
blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes);
if(cmat_shape.lens() != op_out_shape.lens())
{
......@@ -70,6 +74,7 @@ struct rocblas_gemm
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
return op_out_shape;
}
return op.compute_shape(in_shapes);
......@@ -80,37 +85,21 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}
else
{
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
gemm(ctx,
output_shape,
args,
int32_t(alpha),
int32_t(beta),
int8_x4_format,
compute_fp32);
}
return args.back();
}
void batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
{
MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
"} are transposed!");
}
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
"} is transposed!");
}
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
......
......@@ -14,13 +14,15 @@ void gemm(context& ctx,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format);
bool int8_x4_format,
bool compute_fp32);
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format);
bool int8_x4_format,
bool compute_fp32);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -9,6 +9,8 @@
#include <miopen/miopen.h>
#include <migraphx/config.hpp>
#include <sstream>
#ifdef HAS_FIND_MODE_API
extern "C" miopenStatus_t miopenHiddenSetConvolutionFindMode(miopenConvolutionDescriptor_t convDesc,
int findMode);
......@@ -132,12 +134,16 @@ inline convolution_descriptor make_deconv(const T& op)
inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
{
miopenPoolingMode_t mode;
if(op.mode == "max")
if(op.mode == op::pooling_mode::max)
mode = miopenPoolingMax;
else if(op.mode == "average")
else if(op.mode == op::pooling_mode::average)
mode = miopenPoolingAverage;
else
MIGRAPHX_THROW("Unknown mode for pooling: " + op.mode);
{
std::stringstream ss("Unknown mode for pooling: ");
ss << op.mode;
MIGRAPHX_THROW(ss.str());
}
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
int kdims = op.kdims();
......
......@@ -40,9 +40,8 @@ struct hip_prefix_scan_sum : oper<hip_prefix_scan_sum>
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
if(op.exclusive or op.reverse)
MIGRAPHX_THROW("Exclusive and reverse scan not supported");
device::prefix_scan_sum(ctx.get_stream().get(), args[1], args[0], op.axis);
device::prefix_scan_sum(
ctx.get_stream().get(), args[1], args[0], op.axis, op.exclusive, op.reverse);
return args[1];
}
......
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#ifndef MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct context;
operation compile_roialign(context& ctx, const std::vector<shape>& io_shapes, const value& val);
struct prefuse_ops
{
std::string name() const { return "gpu::prefuse_ops"; }
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_ROIALIGN_HPP
#endif // MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
......@@ -2,6 +2,7 @@
#define MIGRAPHX_GUARD_RTGLIB_QUANT_CONVOLUTION_HPP
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/quant_convolution.hpp>
#include <migraphx/gpu/miopen.hpp>
......@@ -14,6 +15,7 @@ struct context;
struct miopen_quant_convolution
{
op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr;
......@@ -22,7 +24,8 @@ struct miopen_quant_convolution
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::quant_convolution::reflect(self.op, f);
return pack_join(migraphx::reflect(self.op, f),
pack(f(self.int8_x4_format, "int8_x4_format")));
}
std::string name() const { return "gpu::quant_convolution"; }
......
......@@ -3,7 +3,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/scatter_none.hpp>
#include <migraphx/gpu/miopen.hpp>
namespace migraphx {
......@@ -14,7 +14,9 @@ struct context;
struct hip_scatter
{
op::scatter op;
// scatter_none is an exact replacement for previous op::scatter,
// renamed to match an Onnx option. Don't use base class op::scatter
op::scatter_none op;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -22,7 +24,7 @@ struct hip_scatter
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::scatter"; }
std::string name() const { return "gpu::scatter_none"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
......
......@@ -17,9 +17,9 @@ struct schedule_model
{
std::size_t streams = 0;
std::size_t concurrency() const;
void sched(module& p, instruction_ref ins, std::size_t n) const;
void wait(module& p, instruction_ref ins, std::size_t wait_id) const;
void record(module& p, instruction_ref ins, std::size_t wait_id) const;
void sched(module& m, instruction_ref ins, std::size_t n) const;
void wait(module& m, instruction_ref ins, std::size_t wait_id) const;
void record(module& m, instruction_ref ins, std::size_t wait_id) const;
std::size_t weight(const operation& op) const;
};
......
......@@ -15,7 +15,7 @@ namespace gpu {
struct sync_device
{
std::string name() const { return "sync_device"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -14,7 +14,7 @@ struct write_literals
context* ctx = nullptr;
std::string name() const { return "gpu::write_literals"; }
void apply(module& p) const;
void apply(module& m) const;
};
} // namespace gpu
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const gathernd_kernel = R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
{
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
gathernd(xs..., settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct gathernd_compiler : compiler<gathernd_compiler>
{
std::vector<std::string> names() const { return {"gathernd"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
auto out_s = inputs.back();
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
options.inputs = inputs;
options.output = out_s;
options.kernel_name = "gathernd_kernel";
options.virtual_inputs = inputs;
// batch_dims
assert(v.contains("batch_dims"));
auto batch_dims = v.at("batch_dims").to<int64_t>();
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);
return compile_hip_code_object(gathernd_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
pointwise(idx, ${transformers})(${lambda}, ${args});
}
}
} // namespace migraphx
)__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler>
{
std::vector<std::string> names() const { return {"pointwise"}; }
static std::size_t oversubscribe_if(bool b)
{
if(b)
return 256;
else
return 1;
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params(
v,
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)},
{"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation&) const
{
assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign",
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(
compile_op(ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const simple_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void reduce_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
}
}
} // namespace migraphx
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
std::vector<std::size_t> reduce_lens;
std::transform(output_lens.begin(),
output_lens.end(),
input_lens.begin(),
std::back_inserter(reduce_lens),
[](auto x, auto y) -> std::size_t {
if(x == y)
return 1;
else
return y;
});
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
rlens.begin(),
rlens.end(),
inputs.front().strides().begin(),
init,
[](auto x, auto y) { return std::min(x, y); },
[](auto len, auto stride) { return len == 1 ? init : stride; });
if(min_stride > 2)
return "lane";
return "block";
}
struct reduce_compiler : compiler<reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = reduce_dims(inputs);
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1)
{
vec = vectorize::elements(faxis, options.virtual_inputs);
}
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs));
if(algo == "block")
{
auto block_size = compute_block_size(relements, 256);
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(algo == "lane")
{
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = "reduce_kernel";
std::string identity = "[](auto x) { return x; }";
auto src = interpolate_string(simple_reduce_kernel,
{{"reduction", v.at("reduction").to<std::string>()},
{"init", v.get("init", std::string{"0"})},
{"read", v.get("read", identity)},
{"write", v.get("write", identity)},
{"algo", algo},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
auto reduce_elements = get_reduce_elements(ins->inputs());
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
v["reduction"] = "op::sum{}";
v["write"] = "op::mean{" + std::to_string(reduce_elements) + "}";
}
else if(op.name() == "reduce_max")
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
}
else if(op.name() == "reduce_prod")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compile_roialign.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -13,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE
static const char* const roialign_kernel = R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
......@@ -37,46 +42,46 @@ __global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y
} // namespace migraphx
int main() {}
)__migraphx__";
operation compile_roialign(context&, const std::vector<shape>& io_shapes, const value& val)
struct roialign_compiler : compiler<roialign_compiler>
{
hip_compile_options options;
auto out_s = io_shapes.back();
options.local = 128;
options.global = compute_global(out_s.elements(), options.local);
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "roialign_kernel";
options.virtual_inputs = io_shapes;
// sampling_ratio
assert(val.contains("sampling_ratio"));
auto sampling_ratio = val.at("sampling_ratio").to<int64_t>();
options.params += " -DSAMPLING_RATIO=" + std::to_string(sampling_ratio);
// pooling_mode
assert(val.contains("mode"));
auto mode = val.at("mode").to<std::string>();
bool is_avg_pooling = (mode == "avg");
options.params += " -DIS_AVG_POOLING=" + std::to_string(static_cast<int>(is_avg_pooling));
// coord_trans_mode
assert(val.contains("coordinate_transformation_mode"));
auto ctm = val.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale
assert(val.contains("spatial_scale"));
float spatial_scale = val.at("spatial_scale").to<float>();
options.params += " -DSPATIAL_SCALE=" + std::to_string(spatial_scale);
return compile_hip_code_object(roialign_kernel, options);
}
} // namespace gpu
std::vector<std::string> names() const { return {"roialign"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 128);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "roialign_kernel";
// sampling_ratio
options.params += " -DSAMPLING_RATIO=" + v.at("sampling_ratio").to<std::string>();
// pooling_mode
auto mode = v.at("mode").to<migraphx::op::pooling_mode>();
std::string is_avg_pooling =
(mode == migraphx::op::pooling_mode::average) ? "true" : "false";
options.params += " -DIS_AVG_POOLING=" + is_avg_pooling;
// coord_trans_mode
auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale
options.params += " -DSPATIAL_SCALE=" + v.at("spatial_scale").to<std::string>();
return compile_hip_code_object(roialign_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
scatternd(xs..., ${reduction}{});
});
}
}
} // namespace migraphx
)__migraphx__";
struct scatternd_compiler : compiler<scatternd_compiler>
{
std::vector<std::string> names() const
{
return {"scatternd_none", "scatternd_add", "scatternd_mul"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.at(1).elements()));
options.inputs = inputs;
options.output = inputs.back();
options.kernel_name = "scatternd_kernel";
options.virtual_inputs = inputs;
auto reduction = "assign_" + v.get("reduction", std::string{"none"});
auto src = interpolate_string(scatternd_kernel, {{"reduction", reduction}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(starts_with(op.name(), "scatternd_"));
auto reduction = op.name().substr(10);
return insert(compile_op(ctx,
to_shapes({ins->inputs().begin() + 1, ins->inputs().end()}),
{{"reduction", reduction}}));
}
compiler_replace insert(const operation& op) const
{
return [=](module& m, instruction_ref ins) {
auto args = ins->inputs();
args.back() =
m.insert_instruction(ins, make_op("hip::copy"), args.front(), args.back());
args.erase(args.begin());
return m.replace_instruction(ins, op, args);
};
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
void* kernargs,
std::size_t size)
{
assert(global > 0);
assert(local > 0);
void* config[] = {
// HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment