Commit 3df20646 authored by Khalique Ahmed

manual merge

parents 1005a693 d0543c96
......@@ -91,28 +91,34 @@ add_library(migraphx_device
device/unary_not.cpp
device/where.cpp
)
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_device)
target_compile_options(migraphx_device PRIVATE -std=c++17 -fno-gpu-rdc -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(migraphx_device migraphx hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
if(CMAKE_CXX_COMPILER MATCHES ".*hcc")
set(AMDGPU_TARGETS "gfx803;gfx900;gfx906" CACHE STRING "")
foreach(AMDGPU_TARGET ${AMDGPU_TARGETS})
target_compile_options(migraphx_device PRIVATE -amdgpu-target=${AMDGPU_TARGET})
target_link_libraries(migraphx_device -amdgpu-target=${AMDGPU_TARGET})
endforeach()
else()
target_compile_options(migraphx_device PRIVATE -Wno-cuda-compat)
endif()
add_library(compile_for_gpu INTERFACE)
target_compile_options(compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns)
target_link_libraries(compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument)
check_cxx_compiler_flag("--cuda-host-only -fhip-lambda-host-device -x hip" HAS_HIP_LAMBDA_HOST_DEVICE)
if(HAS_HIP_LAMBDA_HOST_DEVICE)
message(STATUS "Enable -fhip-lambda-host-device")
target_compile_options(migraphx_device PRIVATE -fhip-lambda-host-device)
target_compile_options(compile_for_gpu INTERFACE -fhip-lambda-host-device)
endif()
set_target_properties(migraphx_device PROPERTIES EXPORT_NAME device)
rocm_set_soversion(migraphx_device ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
add_library(kernel_file_check EXCLUDE_FROM_ALL)
foreach(KERNEL_FILE ${KERNEL_FILES})
get_filename_component(KERNEL_BASE_FILE ${KERNEL_FILE} NAME_WE)
file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp "#include <migraphx/kernels/${KERNEL_BASE_FILE}.hpp>\n")
target_sources(kernel_file_check PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/kernels/include/migraphx/kernels/${KERNEL_BASE_FILE}.cpp)
endforeach()
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
rocm_clang_tidy_check(kernel_file_check)
add_library(migraphx_gpu
abs.cpp
analyze_streams.cpp
......@@ -310,8 +316,12 @@ target_flags(HIP_COMPILER_FLAGS hip::device)
# Remove cuda arch flags
string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REPLACE "$<LINK_LANGUAGE:CXX>" "1" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
string(REPLACE "SHELL:" "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
# Skip library paths since hip will incorrectly treat them as source files
string(APPEND HIP_COMPILER_FLAGS " ")
foreach(_unused RANGE 2)
string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
endforeach()
message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}")
target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER=${CMAKE_CXX_COMPILER}"
......@@ -341,7 +351,7 @@ target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
add_subdirectory(driver)
rocm_install_targets(
TARGETS migraphx_gpu migraphx_device
TARGETS migraphx_gpu migraphx_device compile_for_gpu
INCLUDE
${CMAKE_CURRENT_SOURCE_DIR}/include
)
......
......@@ -9,7 +9,7 @@ namespace gpu {
shape hip_argmax::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
check_shapes{inputs, *this}.has(2);
return op.normalize_compute_shape({inputs.at(0)});
}
......
......@@ -9,7 +9,7 @@ namespace gpu {
shape hip_argmin::compute_shape(const std::vector<shape>& inputs) const
{
check_shapes{inputs, *this}.has(2).standard();
check_shapes{inputs, *this}.has(2);
return op.normalize_compute_shape({inputs.at(0)});
}
......
......@@ -108,12 +108,13 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
auto args_hpp =
generate_args_hpp(options.reduced_inputs.empty() ? options.inputs : options.reduced_inputs);
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"},
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
options.params += " -ftemplate-backtrace-limit=0";
options.params += " -Werror";
auto cos = compile_hip_src(srcs, std::move(options.params), get_device_name());
if(cos.size() != 1)
......
......@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
using namespace migraphx;
namespace migraphx {
${preamble}
......@@ -32,6 +32,8 @@ __global__ void kernel(${params})
}
} // namespace migraphx
int main() {}
)__migraphx__";
......@@ -46,7 +48,7 @@ operation compile_pointwise(context&,
options.local = 1024;
options.inputs = inputs;
options.output = inputs.back();
options.reduced_inputs = reduce_dims(inputs);
options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal";
auto src = interpolate_string(pointwise_kernel,
{{"params", enum_params(inputs.size(), "void * private_p")},
......@@ -60,8 +62,17 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
{
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
auto name = g.create_function(g.generate_module(m).set_attributes({"__device__"}));
return compile_pointwise((ctx), inputs, "&" + name, g.str());
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
auto name =
g.create_function(g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m));
return compile_pointwise((ctx), inputs, "MIGRAPHX_LIFT(" + name + ")", g.str());
}
} // namespace gpu
......
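// Sketch (illustrative, not part of the diff): roughly what one generated point op
// looks like once the ${N} arguments and ${function:...} placeholders above are
// interpolated; the function name and signature here are hypothetical.
template <class T>
__device__ T generated_prelu(T x0, T x1)
{
    return migraphx::where(x0 < 0, x0 * x1, x0); // from the "prelu" template string
}
// compile_pointwise then passes the generated name as MIGRAPHX_LIFT(name), wrapping
// the __device__ function in a lambda so the kernel receives a plain callable object.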
......@@ -14,17 +14,29 @@ namespace gpu {
static const char* const roialign_kernel = R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
using namespace migraphx;
namespace migraphx {
extern "C" {
__global__ void roialign_kernel(void* in_x, void* in_rois, void* in_ind, void* y)
{
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) { roialign(xs...); });
make_tensors()(in_x, in_rois, in_ind, y)([](auto&&... xs) {
auto settings = make_roalign_settings(MIGRAPHX_MAKE_CONSTANT(float{ROIS_OFFSET}),
_c<bool{IS_AVG_POOLING}>,
_c<int64_t{SAMPLING_RATIO}>,
MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));
roialign(xs..., settings);
});
}
}
} // namespace migraphx
int main() {}
)__migraphx__";
......@@ -38,7 +50,7 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
options.inputs = io_shapes;
options.output = out_s;
options.kernel_name = "roialign_kernel";
options.reduced_inputs = io_shapes;
options.virtual_inputs = io_shapes;
// sampling_ratio
assert(val.contains("sampling_ratio"));
......
......@@ -75,8 +75,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
inline auto gs_launch(hipStream_t stream, index_int n, index_int local = 1024)
{
index_int groups = (n + local - 1) / local;
index_int nglobal = std::min<index_int>(1048576, groups) * local;
index_int groups = (n + local - 1) / local;
// max possible number of blocks is set to 1B (1,073,741,824)
index_int nglobal = std::min<index_int>(1073741824, groups) * local;
return [=](auto f) {
launch(stream, nglobal, local)([=](auto idx) __device__ {
......
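// Sketch (illustrative, not part of the diff; plain host C++): how the launch size is
// derived above. The block count is rounded up and then capped, and the grid-stride
// style of gs_launch lets each thread cover more than one element once the cap is hit.
// The helper name and the width of index_int here are assumptions for the sketch only.
#include <algorithm>
#include <cstdint>

using index_int = std::uint64_t;

inline index_int capped_nglobal(index_int n, index_int local = 1024)
{
    index_int groups = (n + local - 1) / local;             // round up to full blocks
    return std::min<index_int>(1073741824, groups) * local; // cap the block count
}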
......@@ -21,33 +21,58 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const index_int max_block_size = 120;
// const index_int max_block_size = 128;
const index_int block_size = compute_block_size(batch_item_num, max_block_size);
gs_launch(stream,
batch_shape.elements() * block_size,
block_size)([=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest();
if(axis == batch_lens.size() - 1)
{
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
auto start_loc = i / block_size * batch_item_num;
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
return input[start_loc + j];
});
auto batch_sum = block_reduce<max_block_size>(
idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max;
return ::exp(to_hip_type(val));
});
auto batch_sum =
block_reduce<max_block_size>(idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
idx.local_stride(batch_item_num, [&](auto j) __device__ {
auto val = input[start_loc + j] - batch_max;
output[start_loc + j] = ::exp(to_hip_type(val)) / batch_sum;
});
});
}
else
{
gs_launch(stream, batch_shape.elements() * block_size, block_size)(
[=](auto i, auto idx) __device__ {
auto data_idx = batch.multi(i / block_size);
auto batch_max = block_reduce<max_block_size>(
idx, max{}, init, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
return input[data_idx];
});
idx.local_stride(batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
});
});
auto batch_sum = block_reduce<max_block_size>(
idx, sum{}, 0, batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
return ::exp(to_hip_type(val));
});
idx.local_stride(batch_item_num, [&](auto j) __device__ {
data_idx[axis] = j;
auto val = input[data_idx] - batch_max;
output[data_idx] = ::exp(to_hip_type(val)) / batch_sum;
});
});
}
});
}
......
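// Sketch (illustrative, not part of the diff): the indexing used by the new last-axis
// fast path above. When the reduced axis is the innermost one of a packed, row-major
// batch, the elements of one reduction are contiguous, so no multi-index conversion is
// needed; the general path keeps the data_idx[axis] form. Helper name is hypothetical.
#include <cstddef>

inline std::size_t softmax_fast_offset(std::size_t batch, std::size_t batch_item_num, std::size_t j)
{
    return batch * batch_item_num + j; // matches start_loc + j in the kernel
}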
......@@ -62,6 +62,8 @@ struct fusion
keep_alive(std::move(t));
}
bool empty() const { return fp == nullptr; }
op_t operator[](std::size_t i) const
{
assert(fp);
......@@ -125,12 +127,11 @@ struct fusion
return shape{shape::int8_type, {ws_size}};
}
void compile(context& ctx)
bool compile(context& ctx)
{
assert(fp);
auto status = miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get());
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("Compiling fusion plan failed");
return miopenCompileFusionPlan(ctx.get_stream().get_miopen(), fp.get()) ==
miopenStatusSuccess;
}
argument execute(context& ctx,
......@@ -169,7 +170,7 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
const auto device_name = split_string(get_device_name(), ':').front();
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(not contains(get_supported_archs(), device_name))
return false;
if(enabled(MIGRAPHX_DISABLE_MIOPEN_FUSION{}))
......@@ -561,6 +562,117 @@ struct find_mul_add_relu
}
};
struct miopen_fusion
{
struct fuse_op_data
{
operation op;
float alpha = 1;
float beta = 0;
};
struct fuse_op : fuse_op_data, reflect_equality<fuse_op>, reflect_stream<fuse_op>
{
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"), f(self.alpha, "alpha"), f(self.beta, "beta"));
}
};
std::vector<fuse_op> ops = {};
fusion f = {};
std::function<void(context&, const fusion&, const std::vector<argument>&)> execute;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.ops, "ops"));
}
value compile(context& ctx, const shape&, std::vector<shape> inputs)
{
// Compensate for allocation
inputs.pop_back();
std::size_t i = 0;
f = fusion(inputs[i]);
i++;
std::vector<std::function<void(const fused_operator_args&, const std::vector<argument>&)>>
invokers;
for(auto&& fop : ops)
{
if(i > inputs.size())
{
f = {};
return {};
}
if(fop.op.name() == "convolution")
{
auto* mop = f.create_conv(any_cast<op::convolution>(fop.op), inputs[i]);
invokers.push_back(
[=](const fused_operator_args& fargs, const std::vector<argument>& args) {
miopenSetOpArgsConvForward(
fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
});
i++;
}
else if(fop.op.name() == "add")
{
auto* mop = f.create_bias(inputs[i]);
invokers.push_back(
[=](const fused_operator_args& fargs, const std::vector<argument>& args) {
miopenSetOpArgsBiasForward(
fargs.get(), mop, &fop.alpha, &fop.beta, args[i].implicit());
});
i++;
}
else if(fop.op.name() == "relu")
{
auto* mop = f.create_relu();
invokers.push_back([=](const fused_operator_args& fargs,
const std::vector<argument>&) {
miopenSetOpArgsActivForward(fargs.get(), mop, &fop.alpha, &fop.beta, 0, 0, 0);
});
}
else
{
f = {};
return {};
}
}
if(not f.compile(ctx))
{
f = {};
return {};
}
execute = [invokers](context& c, const fusion& ff, const std::vector<argument>& args) {
auto fargs = make_fused_args();
for(auto&& invoker : invokers)
invoker(fargs, args);
ff.execute(c, fargs, args.front(), args.back());
};
return {{"workspace", f.get_workspace(ctx).bytes()}};
}
void finalize(context& ctx, const shape& output_shape, const std::vector<shape>& inputs)
{
if(not f.empty())
return;
auto v = compile(ctx, output_shape, inputs);
if(not v.is_object())
MIGRAPHX_THROW("Failed to compile fusion plan");
}
std::string name() const { return "gpu::miopen_fusion"; }
shape compute_shape(const std::vector<shape>& inputs) const
{
if(ops.empty())
return {};
// TODO: Check number of arguments
return ops.front().op.compute_shape({inputs[0], inputs[1]});
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
execute(ctx, f, args);
return args.back();
}
};
struct miopen_conv_bias
{
op::convolution op;
......@@ -596,7 +708,8 @@ struct miopen_conv_bias
f = fusion(inputs[0]);
conv = f.create_conv(op, inputs[1]);
bias = f.create_bias(inputs[3]);
f.compile(ctx);
if(not f.compile(ctx))
MIGRAPHX_THROW("Failed to compile fusion plan");
}
shape get_workspace(context& ctx) { return f.get_workspace(ctx); }
......@@ -683,6 +796,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
p.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
}
inline auto precompile_name(std::string s) // NOLINT
{
return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op")
return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s);
});
}
template <class... Ms>
auto conv_bias_pointwise(Ms... ms)
{
return precompile_name("pointwise")(
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")),
ms...);
}
struct find_conv_bias
{
context* ctx = nullptr;
......@@ -709,6 +841,46 @@ struct find_conv_bias_relu
}
};
struct find_conv_pointwise
{
context* ctx = nullptr;
auto matcher() const
{
return precompile_name("pointwise")(
match::nargs(3),
match::either_arg(0, 1)(bias_shape(match::used_once()).bind("bias"),
fusable_conv(match::used_once()).bind("conv")));
}
void apply(module& m, match::matcher_result r) const
{
auto conv_ins = r.instructions["conv"];
auto bias_ins = r.instructions["bias"];
auto ins = r.result;
auto input_ins = conv_ins->inputs().at(0);
auto weights_ins = conv_ins->inputs().at(1);
auto conv_op = any_cast<miopen_convolution>(conv_ins->get_operator()).op;
auto alloc_ins = ins->inputs().back();
module_ref pm = ins->module_inputs().front();
miopen_fusion op{};
op.ops.push_back({{conv_op}});
for(auto&& i : *pm)
{
if(i.name()[0] == '@')
continue;
auto inputs = to_shapes(i.inputs());
op.ops.push_back({{i.get_operator()}});
}
std::vector<instruction_ref> inputs = {input_ins, weights_ins, bias_ins, alloc_ins};
auto v = op.compile(*ctx, ins->get_shape(), to_shapes(inputs));
if(not v.is_object())
return;
m.replace_instruction(ins, op, inputs);
}
};
struct find_gemm_add
{
auto matcher() const
......@@ -778,6 +950,7 @@ void fuse_ops::apply(module& p) const
match::find_matches(p, find_triadd{});
match::find_matches(p,
find_layernorm{},
find_conv_pointwise{ctx},
find_conv_bias_relu{ctx},
find_conv_bias{ctx},
find_add_gelu{},
......
......@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
const std::vector<argument>& args,
T alpha,
T beta,
bool int8_x4_format)
bool int8_x4_format,
bool compute_fp32)
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
......@@ -65,13 +66,11 @@ void gemm_impl(context& ctx,
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
if(args[0].get_shape().type() == shape::half_type)
compute_type = rocblas_datatype_f32_r;
// if(ctx.get_stream().get_device_name() == "gfx908")
// {
// if(args[0].get_shape().type() == shape::half_type)
// compute_type = rocblas_datatype_f32_r;
// }
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
......@@ -84,6 +83,13 @@ void gemm_impl(context& ctx,
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
if(compute_fp32)
{
alpha_r = alpha;
beta_r = beta;
}
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
......@@ -109,14 +115,14 @@ void gemm_impl(context& ctx,
n,
m,
k,
&alpha,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
&beta,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
......@@ -137,7 +143,7 @@ void gemm_impl(context& ctx,
n,
m,
k,
&alpha,
&alpha_r,
to_pointer(args.at(1)),
arg_type,
ldb,
......@@ -146,7 +152,7 @@ void gemm_impl(context& ctx,
arg_type,
lda,
m * k,
&beta,
&beta_r,
to_pointer(args[2]),
output_type,
ldc,
......@@ -169,9 +175,10 @@ void gemm(context& ctx,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format)
bool int8_x4_format,
bool compute_fp32)
{
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}
void gemm(context& ctx,
......@@ -179,9 +186,10 @@ void gemm(context& ctx,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format)
bool int8_x4_format,
bool compute_fp32)
{
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format);
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}
} // namespace gpu
......
......@@ -16,7 +16,7 @@ struct hip_compile_options
shape output;
std::string kernel_name = "kernel";
std::string params = "";
std::vector<shape> reduced_inputs = {};
std::vector<shape> virtual_inputs = {};
};
operation compile_hip_code_object(const std::string& content, hip_compile_options options);
......
......@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
size_t batch_item_num = batch_lens[axis];
batch_lens[axis] = 1;
migraphx::shape batch_shape{arg_shape.type(), batch_lens};
migraphx::shape std_arg_shape{arg_shape.type(), arg_shape.lens()};
hip_visit_all(arg, arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
hip_visit_all(arg, std_arg_shape, batch_shape)([&](auto input, auto arg_s, auto batch_s) {
auto* output = device_cast(result.get<int64_t>().data());
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
// use one block for items in one batch.
......
......@@ -25,6 +25,7 @@ struct rocblas_gemm
float alpha = 1;
float beta = 0;
bool int8_x4_format = true;
bool compute_fp32 = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
......@@ -80,11 +81,11 @@ struct rocblas_gemm
{
if(this->name() == "gpu::gemm")
{
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format);
gemm(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}
else
{
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format);
gemm(ctx, output_shape, args, int32_t(alpha), int32_t(beta), int8_x4_format, compute_fp32);
}
return args.back();
}
......
......@@ -14,13 +14,15 @@ void gemm(context& ctx,
const std::vector<argument>& args,
float alpha,
float beta,
bool int8_x4_format);
bool int8_x4_format,
bool compute_fp32);
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
int32_t alpha,
int32_t beta,
bool int8_x4_format);
bool int8_x4_format,
bool compute_fp32);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -176,23 +176,23 @@ struct array
}
};
template <class T, T... xs>
struct integral_const_array : array<T, sizeof...(xs)>
template <class T, T... Xs>
struct integral_const_array : array<T, sizeof...(Xs)>
{
using base_array = array<T, sizeof...(xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({xs...}) {}
using base_array = array<T, sizeof...(Xs)>;
MIGRAPHX_DEVICE_CONSTEXPR integral_const_array() : base_array({Xs...}) {}
};
template <class T, T... xs, class F>
constexpr auto transform(integral_const_array<T, xs...>, F f)
template <class T, T... Xs, class F>
constexpr auto transform(integral_const_array<T, Xs...>, F f)
{
return integral_const_array<T, f(xs)...>{};
return integral_const_array<T, f(Xs)...>{};
}
template <class T, T... xs, class U, U... ys, class F>
constexpr auto transform(integral_const_array<T, xs...>, integral_const_array<U, ys...>, F f)
template <class T, T... Xs, class U, U... Ys, class F>
constexpr auto transform(integral_const_array<T, Xs...>, integral_const_array<U, Ys...>, F f)
{
return integral_const_array<T, f(xs, ys)...>{};
return integral_const_array<T, f(Xs, Ys)...>{};
}
template <index_int... Ns>
......
#ifndef MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#define MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
namespace migraphx {
inline __host__ __device__ void
assert_fail(const char* assertion, const char* file, unsigned int line, const char* function)
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN
#else
// NOLINTNEXTLINE
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
namespace debug {
struct swallow
{
template <class... Ts>
constexpr swallow(Ts&&...)
{
}
};
template <size_t N>
struct print_buffer
{
char buffer[N + 1] = {0};
char* pos = buffer;
constexpr void append(char c)
{
if(c == 0)
return;
if(pos < buffer + N)
{
*pos = c;
pos++;
}
}
template <size_t M>
constexpr void append(const char (&array)[M])
{
for(int i = 0; i < M; i++)
append(array[i]);
}
};
template <class... Ts>
__host__ __device__ void print(const Ts&... xs)
{
const auto size = (sizeof(xs) + ...);
print_buffer<size> buffer;
swallow{(buffer.append(xs), 0)...};
printf("%s", buffer.buffer);
}
} // namespace debug
// noreturn cannot be used on this function because abort in hip is broken
template <class T1, class T2, class T3, class T4>
MIGRAPHX_HIP_NORETURN inline __host__ __device__ void
assert_fail(const T1& assertion, const T2& file, const T3& line, const T4& function)
{
printf("%s:%u: %s: assertion '%s' failed.\n", file, line, function, assertion);
// printf is broken on hip with more than one argument, so use a simple print function instead
debug::print(file, ":", line, ": ", function, ": assertion '", assertion, "' failed.\n");
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort();
}
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto... xs) { \
assert_fail(xs...); \
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
#define MIGRAPHX_ASSERT(cond) \
((cond) ? void(0) : [](auto&&... private_migraphx_xs) { \
assert_fail(private_migraphx_xs...); \
}(#cond, __FILE__, MIGRAPHX_STRINGIZE(__LINE__), __PRETTY_FUNCTION__))
#else
#define MIGRAPHX_ASSERT(cond)
#endif
......
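// Usage sketch (illustrative, not part of the diff): with MIGRAPHX_DEBUG defined, a
// failing assertion routes through debug::print, which packs every piece into one
// buffer and issues a single printf("%s", ...); the line number arrives already
// stringized via MIGRAPHX_STRINGIZE, so no numeric formatting happens on the device.
// The helper below is hypothetical.
__device__ int checked_index(int i, int n)
{
    MIGRAPHX_ASSERT(i >= 0 and i < n);
    return i;
}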
......@@ -16,6 +16,19 @@ struct swallow
template <index_int>
using ignore = swallow;
template <class... Fs>
struct overloaded : Fs...
{
using Fs::operator()...;
overloaded(Fs... fs) : Fs(fs)... {}
};
template <class... Fs>
overloaded<Fs...> overload(Fs... fs)
{
return {fs...};
}
namespace detail {
template <class R>
......@@ -124,12 +137,48 @@ constexpr void each_args(F)
{
}
template <class F, class T>
constexpr auto fold_impl(F&&, T&& x)
{
return static_cast<T&&>(x);
}
template <class F, class T, class U, class... Ts>
constexpr auto fold_impl(F&& f, T&& x, U&& y, Ts&&... xs)
{
return fold_impl(f, f(static_cast<T&&>(x), static_cast<U&&>(y)), static_cast<Ts&&>(xs)...);
}
template <class F>
constexpr auto fold(F f)
{
return [=](auto&&... xs) { return fold_impl(f, static_cast<decltype(xs)&&>(xs)...); };
}
template <class... Ts>
auto pack(Ts... xs)
constexpr auto pack(Ts... xs)
{
return [=](auto f) { return f(xs...); };
}
template <class Compare, class P1, class P2>
constexpr auto pack_compare(Compare compare, P1 p1, P2 p2)
{
return p1([&](auto... xs) {
return p2([&](auto... ys) {
auto c = [&](auto x, auto y) -> int {
if(compare(x, y))
return 1;
else if(compare(y, x))
return -1;
else
return 0;
};
return fold([](auto x, auto y) { return x ? x : y; })(c(xs, ys)..., 0);
});
});
}
template <index_int N>
constexpr auto arg_c()
{
......@@ -168,8 +217,13 @@ constexpr auto transform_args(F f, Fs... fs)
return [=](auto... xs) { return transform_args(f)(xs...)(transform_args(fs...)); };
}
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
([](auto&&... xs) { return (__VA_ARGS__)(static_cast<decltype(xs)>(xs)...); })
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
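// Sketch (illustrative, not part of the diff): MIGRAPHX_LIFT, as used by
// compile_pointwise above, turns a function template or overload set into an ordinary
// object that can be passed around, and the new MIGRAPHX_RETURNS form keeps the
// wrapper SFINAE-friendly. Names below are hypothetical.
template <class T>
__device__ T twice(T x) { return x + x; }

template <class F, class T>
__device__ auto apply_op(F f, T x) MIGRAPHX_RETURNS(f(x))

// apply_op(MIGRAPHX_LIFT(twice), 3);  // ok: the template is wrapped in a lambda
// apply_op(twice, 3);                 // would not compile: 'twice' names a template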
#ifndef MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
#define MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
namespace migraphx {
template <class F>
struct generic_constant
{
static constexpr auto value = F{}();
using value_type = decltype(value);
using type = generic_constant;
constexpr operator value_type() const noexcept { return value; }
constexpr value_type operator()() const noexcept { return value; }
};
template <class F>
constexpr generic_constant<F> make_generic_constant(F)
{
return {};
}
// NOLINTNEXTLINE
#define MIGRAPHX_MAKE_CONSTANT(x) \
make_generic_constant([] { \
struct fun \
{ \
constexpr auto operator()() const { return x; } \
}; \
return fun{}; \
}())
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GENERIC_CONSTANT_HPP
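// Usage sketch (illustrative, not part of the diff): MIGRAPHX_MAKE_CONSTANT captures an
// expression in a stateless functor type, so the value rides along in the type system
// and stays usable as a compile-time constant, as in the roialign kernel above.
// The helper below and its argument names are hypothetical.
template <class Scale>
__device__ float apply_scale(float x, Scale)
{
    static_assert(Scale::value >= 0, "value is available at compile time");
    return x * Scale::value;
}
// apply_scale(y, MIGRAPHX_MAKE_CONSTANT(float{SPATIAL_SCALE}));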
#ifndef MIGRAPHX_GUARD_KERNELS_HIP_HPP
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#define MIGRAPHX_GUARD_KERNELS_INDEX_HPP
#include <hip/hip_runtime.h>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/types.hpp>
namespace migraphx {
......@@ -17,7 +17,7 @@ struct index
#ifdef MIGRAPHX_NGLOBAL
return MIGRAPHX_NGLOBAL;
#else
return blockDim.x * gridDim.x;
return blockDim.x * gridDim.x; // NOLINT
#endif
}
......@@ -26,7 +26,7 @@ struct index
#ifdef MIGRAPHX_NLOCAL
return MIGRAPHX_NLOCAL;
#else
return blockDim.x;
return blockDim.x; // NOLINT
#endif
}
......@@ -53,7 +53,7 @@ struct index
inline __device__ index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x};
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}
} // namespace migraphx
......