Commit 32b83c9c authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into inner_bcast_fix

parents 92f5a6cd 434a06cf
...@@ -26,7 +26,9 @@ ...@@ -26,7 +26,9 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/targets.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,7 +43,7 @@ struct index ...@@ -41,7 +43,7 @@ struct index
__device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT __device__ index_int nglobal() const { return blockDim.x * gridDim.x; } // NOLINT
__device__ index_int nlocal() const { return blockDim.x; } // NOLINT __device__ index_int nlocal() const { return blockDim.x; } // NOLINT
template <class F> template <class F>
__device__ void global_stride(index_int n, F f) const __device__ void global_stride(index_int n, F f) const
...@@ -81,6 +83,19 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -81,6 +83,19 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
dim3 nthreads(local); dim3 nthreads(local);
// cppcheck-suppress UseDeviceLaunch // cppcheck-suppress UseDeviceLaunch
hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f); hipLaunchKernelGGL((launcher<f_type>), nblocks, nthreads, 0, stream, f);
hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess)
{
std::string message = hipGetErrorString(kernel_launch_status);
if(not contains(get_targets(), get_device_name()))
{
message += ". Trying to run a kernel for " + get_device_name() +
" but MIGraphX was built for targets " + get_targets_as_string() +
". Please rebuild MIGraphX with -DGPU_TARGETS='" + get_device_name() +
"'.";
}
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + message);
}
}; };
} }
......
...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl( ...@@ -124,7 +124,7 @@ void nary_broadcast_vec_impl(
buffer[i] = binput.data()[i]; buffer[i] = binput.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl( ...@@ -219,7 +219,7 @@ void nary_double_broadcast_vec_impl(
buffer[i + bdim_vec_len] = binput2.data()[i]; buffer[i + bdim_vec_len] = binput2.data()[i];
} }
__syncthreads(); __syncthreads();
auto* bp = as_pointer(buffer); const auto* bp = as_pointer(buffer);
// Process the data // Process the data
for(size_t i = idx.global; i < nelements; i += nglobal) for(size_t i = idx.global; i < nelements; i += nglobal)
{ {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Splits the compile-time target string (e.g. "gfx900;gfx906;...") into
// one entry per architecture. Called once to seed the get_targets() cache.
static std::vector<std::string> parse_targets() { return split_string(MIGRAPHX_GPU_TARGETS, ';'); }
// Returns the list of GPU architectures this MIGraphX build supports.
// The compile-time target string is parsed once, on first use, and the
// resulting vector is cached for the lifetime of the process.
const std::vector<std::string>& get_targets()
{
    static const std::vector<std::string> targets = parse_targets();
    return targets;
}
// Joins the compiled-in target list into a single comma-separated string,
// suitable for embedding in error messages.
std::string get_targets_as_string() { return join_strings(get_targets(), ", "); }
// Queries the HIP runtime for the ordinal of the currently selected device.
// Throws if the runtime reports no usable device.
static int get_device_id()
{
    int device_id = -1;
    if(hipGetDevice(&device_id) != hipSuccess)
        MIGRAPHX_THROW("No device");
    return device_id;
}
// Returns the architecture name of the current HIP device, as reported by
// the runtime in gcnArchName (e.g. "gfx908:sramecc+:xnack-").
std::string get_device_name()
{
    hipDeviceProp_t properties{};
    if(hipGetDeviceProperties(&properties, get_device_id()) != hipSuccess)
        MIGRAPHX_THROW("Failed to get device properties");
    return properties.gcnArchName;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Semicolon-separated list of GPU architectures this build targets.
// "@GPU_TARGETS@" is a placeholder substituted at build-configuration time
// (configure_file-style) with the actual target list.
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
// Targets parsed from MIGRAPHX_GPU_TARGETS; parsed once and cached.
const std::vector<std::string>& get_targets();
// The same target list joined into a comma-separated string for messages.
std::string get_targets_as_string();
// Architecture name (gcnArchName) of the currently selected HIP device.
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
...@@ -72,12 +72,12 @@ struct hip_heap_vector ...@@ -72,12 +72,12 @@ struct hip_heap_vector
index_int l = 2 * index + 1; index_int l = 2 * index + 1;
index_int r = 2 * index + 2; index_int r = 2 * index + 2;
if(l < n && compare(data[data_index(l)], data[data_index(index)])) if(l < n and compare(data[data_index(l)], data[data_index(index)]))
{ {
index = l; index = l;
} }
if(r < n && compare(data[data_index(r)], data[data_index(index)])) if(r < n and compare(data[data_index(r)], data[data_index(index)]))
{ {
index = r; index = r;
if(compare(data[data_index(l)], data[data_index(r)])) if(compare(data[data_index(l)], data[data_index(r)]))
......
...@@ -31,20 +31,6 @@ namespace migraphx { ...@@ -31,20 +31,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class HipDeviceProp>
std::string get_arch_name(rank<0>, const HipDeviceProp& props)
{
return "gfx" + std::to_string(props.gcnArch);
}
template <class HipDeviceProp>
auto get_arch_name(rank<1>, const HipDeviceProp& props) -> decltype(std::string(props.gcnArchName))
{
return std::string(props.gcnArchName);
}
std::string get_arch_name(const hipDeviceProp_t& props) { return get_arch_name(rank<1>{}, props); }
int get_device_id() int get_device_id()
{ {
int device; int device;
...@@ -60,7 +46,7 @@ std::string get_device_name() ...@@ -60,7 +46,7 @@ std::string get_device_name()
auto status = hipGetDeviceProperties(&props, get_device_id()); auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess) if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties"); MIGRAPHX_THROW("Failed to get device properties");
return get_arch_name(props); return props.gcnArchName;
} }
} // namespace gpu } // namespace gpu
......
...@@ -86,7 +86,7 @@ struct mlir_op ...@@ -86,7 +86,7 @@ struct mlir_op
size_t param_cnt = 0; size_t param_cnt = 0;
std::vector<std::string> names = mod->get_parameter_names(); std::vector<std::string> names = mod->get_parameter_names();
std::sort(names.begin(), names.end()); std::sort(names.begin(), names.end());
for(std::string param_name : names) for(const std::string& param_name : names)
{ {
ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++]; ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
} }
...@@ -103,7 +103,10 @@ struct mlir_op ...@@ -103,7 +103,10 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
return ins_shapes[ins->inputs().at(0)].with_type(type); auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -119,6 +122,33 @@ struct mlir_op ...@@ -119,6 +122,33 @@ struct mlir_op
MIGRAPHX_REGISTER_OP(mlir_op); MIGRAPHX_REGISTER_OP(mlir_op);
namespace { namespace {
// Copies a gemm/convolution instruction into submodule `mm`, pulling any
// chain of shape-only ops (slice/transpose/contiguous/reshape) feeding its
// inputs into the submodule as well. Returns the new anchor instruction
// inside `mm` together with the instructions in the outer module that should
// become the fused op's inputs.
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op)
{
    std::vector<instruction_ref> outer_inputs;
    std::vector<instruction_ref> inner_inputs;
    size_t param_idx = 0;
    for(instruction_ref in : gemm_based_op->inputs())
    {
        // Walk upward past shape-only ops, remembering them in encounter order.
        std::vector<operation> peeled;
        while(contains({"slice", "transpose", "contiguous", "reshape"}, in->name()))
        {
            peeled.push_back(in->get_operator());
            in = in->inputs().at(0);
        }
        outer_inputs.push_back(in);
        // Mirror the outer producer as a parameter, then replay the peeled ops
        // (outermost first) on top of it inside the submodule.
        instruction_ref cur =
            mm->add_parameter("y" + std::to_string(param_idx++), in->get_shape());
        for(const auto& shape_op : reverse(peeled))
        {
            cur = mm->add_instruction(shape_op, {cur});
        }
        inner_inputs.push_back(cur);
    }
    instruction_ref fused = mm->add_instruction(gemm_based_op->get_operator(), inner_inputs);
    return {fused, outer_inputs};
}
MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
{ {
...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins) ...@@ -134,7 +164,7 @@ MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
return true; return true;
} }
struct find_mlir_op struct find_mlir_fused_ops
{ {
auto matcher() const auto matcher() const
{ {
...@@ -163,34 +193,6 @@ struct find_mlir_op ...@@ -163,34 +193,6 @@ struct find_mlir_op
return ins_map; return ins_map;
} }
std::tuple<instruction_ref, std::vector<instruction_ref>>
fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
{
std::vector<instruction_ref> top_inputs;
std::vector<instruction_ref> imm_inputs;
size_t input_cnt = 0;
for(instruction_ref input : gemm_based_op->inputs())
{
std::vector<operation> op_stream;
while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
{
op_stream.push_back(input->get_operator());
input = input->inputs().at(0);
}
top_inputs.push_back(input);
instruction_ref prev_input =
mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
for(const auto& op : reverse(op_stream))
{
prev_input = mm->add_instruction(op, {prev_input});
}
imm_inputs.push_back(prev_input);
}
instruction_ref new_gemm_based_op =
mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
return {new_gemm_based_op, top_inputs};
}
// Whitelist supported fusion options, including imposing type constraints // Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function) // for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types. // on particular types.
...@@ -210,42 +212,46 @@ struct find_mlir_op ...@@ -210,42 +212,46 @@ struct find_mlir_op
return false; return false;
} }
const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"}; const std::initializer_list<std::string> any_type_ops = {"@literal", "@param", "@return"};
const std::initializer_list<std::string> no_bool_ops = {"convolution", const std::initializer_list<std::string> no_bool_ops = {
"quant_convolution", "convolution",
"dot", "quant_convolution",
"quant_dot", "dot",
"add", "quant_dot",
"clip", "add",
"relu", "clip",
"sub", "relu",
"mul", "sub",
"div", "mul",
"pow", "div",
"where", "pow",
"quantizelinear", "where",
"dequantizelinear", "quantizelinear",
"abs", "dequantizelinear",
"neg"}; "abs",
const std::initializer_list<std::string> fp_only_ops = {"ceil", "neg",
"erf", };
"exp", const std::initializer_list<std::string> fp_only_ops = {
"floor", "ceil",
"log", "erf",
"recip", "exp",
"rsqrt", "floor",
"sigmoid" "log",
"softmax", "recip",
"tanh"}; "rsqrt",
"sigmoid",
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type); bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
if(contains(any_type_ops, name)) if(contains(any_type_ops, name))
return true; return true;
if(result_type != type_t::bool_type && contains(no_bool_ops, name)) if(result_type != type_t::bool_type and contains(no_bool_ops, name))
return true; return true;
if(is_float && contains(fp_only_ops, name)) if(is_float and contains(fp_only_ops, name))
return true; return true;
// Only conversions between floating types are known to be unambigiously // Only conversions between floating types are known to be unambigiously
// supported. // supported.
if(is_float && name == "convert") if(is_float and name == "convert")
{ {
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) { return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type()); return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
...@@ -296,20 +302,115 @@ struct find_mlir_op ...@@ -296,20 +302,115 @@ struct find_mlir_op
} }
}; };
// Base for matchers that offload a single gemm-like instruction to MLIR on
// its own (no surrounding fusion). Derived structs supply matcher().
struct find_mlir_standalone_op
{
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto conv_based_op = r.result;
        // enable only for fp32/fp16/i8 types
        if(std::any_of(conv_based_op->inputs().begin(), conv_based_op->inputs().end(), [&](auto i) {
               return not contains(
                   {shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
                   i->get_shape().type());
           }))
            return;
        // Each offloaded op gets its own uniquely named submodule.
        static size_t counter = 0;
        module_ref mm = mpm.create_module("mlir_" + std::to_string(counter++));
        mm->set_bypass();
        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, conv_based_op);
        mm->add_return({anchor_op});
        // Swap the original instruction for an mlir_op wrapping the submodule.
        mpm.get_module().replace_instruction(
            conv_based_op, mlir_op{conv_based_op->get_operator()}, top_inputs, {mm});
    }
};
// Offloads standalone (unfused) convolutions to MLIR.
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
    auto matcher() const { return match::name("convolution"); }
};
// Offloads standalone (unfused) dot/gemm instructions to MLIR.
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
    auto matcher() const { return match::name("dot"); }
};
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS);
// True when MIGRAPHX_MLIR_USE_SPECIFIC_OPS is unset or empty, meaning
// MIGraphX decides by itself which operations to delegate to MLIR.
bool is_self_decide() { return string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "").empty(); }
bool is_requested(std::string_view option)
{
assert(not is_self_decide());
auto string_value = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, "");
const auto options = split_string(string_value, ',');
return contains(options, option);
}
// Decides whether MLIR should handle the given op category. When the user set
// MIGRAPHX_MLIR_USE_SPECIFIC_OPS, honor that list verbatim; otherwise apply
// the built-in policy: fusions are always delegated, standalone convolutions
// only on gfx110* devices, everything else stays with MIGraphX.
bool is_enabled(std::string_view op_name, context* ctx)
{
    if(not is_self_decide())
        return is_requested(op_name);
    if(op_name == "fused")
        return true;
    if(op_name == "convolution")
    {
        // Without a context we cannot identify the device, so stay conservative.
        if(ctx == nullptr)
            return false;
        const auto& device = ctx->get_current_device();
        const std::string navi_family{"gfx110"};
        return starts_with(device.get_gfx_name(), navi_family);
    }
    return false;
}
} // namespace } // namespace
#endif #endif // MIGRAPHX_MLIR
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
match::find_matches(mpm, find_mlir_op{}); if(is_enabled("fused", this->ctx))
{
match::find_matches(mpm, find_mlir_fused_ops{});
}
if(is_enabled("convolution", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_convolution_op{});
}
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr) ...@@ -55,7 +55,7 @@ bool is_device_ptr(const void* ptr)
auto status = hipPointerGetAttributes(&attr, ptr); auto status = hipPointerGetAttributes(&attr, ptr);
if(status != hipSuccess) if(status != hipSuccess)
return false; return false;
return attr.memoryType == hipMemoryTypeDevice; return attr.type == hipMemoryTypeDevice;
} }
std::size_t get_available_gpu_memory() std::size_t get_available_gpu_memory()
......
...@@ -58,6 +58,8 @@ struct hiprtc_src_file ...@@ -58,6 +58,8 @@ struct hiprtc_src_file
} }
}; };
MIGRAPHX_GPU_EXPORT bool hip_has_flags(const std::vector<std::string>& flags);
MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc( MIGRAPHX_GPU_EXPORT std::vector<std::vector<char>> compile_hip_src_with_hiprtc(
std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch); std::vector<hiprtc_src_file> srcs, std::string params, const std::string& arch);
......
...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); ...@@ -46,13 +46,7 @@ using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
struct hip_device struct hip_device
{ {
hip_device() hip_device() : device_props{} { add_stream(); }
{
device_props.gcnArchName[0] = '\0';
device_props.gcnArch = 0;
device_props.multiProcessorCount = 0;
add_stream();
}
hip_device(std::size_t id, std::size_t n) : device_id(id) hip_device(std::size_t id, std::size_t n) : device_id(id)
{ {
...@@ -171,7 +165,7 @@ struct hip_device ...@@ -171,7 +165,7 @@ struct hip_device
std::size_t stream_id() const { return current_stream; } std::size_t stream_id() const { return current_stream; }
std::string get_device_name() const { return get_arch_name(device_props); } std::string get_device_name() const { return device_props.gcnArchName; }
std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); } std::string get_gfx_name() const { return trim(split_string(get_device_name(), ':').front()); }
......
...@@ -84,8 +84,10 @@ struct miopen_convolution ...@@ -84,8 +84,10 @@ struct miopen_convolution
{ {
check_shapes{inputs, op}.has(4); check_shapes{inputs, op}.has(4);
std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2); std::vector<shape> conv_inputs(inputs.begin(), inputs.begin() + 2);
check_shapes{conv_inputs, *this}.max_ndims(5).packed_layouts( check_shapes{conv_inputs, *this}
{{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}}); .max_ndims(5)
.packed_layouts({{0, 1, 2}, {0, 1, 2, 3}, {0, 2, 3, 1}, {0, 1, 2, 3, 4}})
.same_layout();
return migraphx::compute_shape<Op>(op, conv_inputs); return migraphx::compute_shape<Op>(op, conv_inputs);
} }
......
...@@ -33,8 +33,6 @@ namespace migraphx { ...@@ -33,8 +33,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string get_arch_name(const hipDeviceProp_t& props);
MIGRAPHX_GPU_EXPORT std::string get_device_name(); MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT int get_device_id(); MIGRAPHX_GPU_EXPORT int get_device_id();
......
...@@ -92,7 +92,7 @@ struct hip_sync_stream ...@@ -92,7 +92,7 @@ struct hip_sync_stream
return inputs.front(); return inputs.front();
} }
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const argument compute(const context& ctx, const shape&, const std::vector<argument>& args) const
{ {
gpu_sync(ctx); gpu_sync(ctx);
if(args.empty()) if(args.empty())
......
...@@ -37,7 +37,7 @@ struct module; ...@@ -37,7 +37,7 @@ struct module;
namespace gpu { namespace gpu {
MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m); MIGRAPHX_GPU_EXPORT std::string dump_mlir(const module& m);
MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& ctx, MIGRAPHX_GPU_EXPORT code_object_op compile_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<instruction_ref>& inputs, const std::vector<instruction_ref>& inputs,
const value& solution); const value& solution);
...@@ -47,8 +47,10 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, ...@@ -47,8 +47,10 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
code_object_op co, code_object_op co,
const std::vector<instruction_ref>& inputs); const std::vector<instruction_ref>& inputs);
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(module m, MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
const std::vector<shape>& inputs); module m,
const std::vector<shape>& inputs,
bool exhaustive);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -300,7 +300,8 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const auto& b_shape = inputs[1]; const auto& b_shape = inputs[1];
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto rank = a_shape.lens().size(); // cppcheck-suppress unreadVariable
auto rank = a_shape.ndim();
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto m = c_shape.lens()[rank - 2]; auto m = c_shape.lens()[rank - 2];
......
...@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -37,7 +37,7 @@ struct mlir_compiler : compiler<mlir_compiler>
operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; } operation compile_op(context&, const std::vector<shape>&, const value&) const { return {}; }
compiler_replace compiler_replace
compile(context& ctx, instruction_ref ins, const operation&, const value& solution) const compile(const context& ctx, instruction_ref ins, const operation&, const value& solution) const
{ {
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
assert(smod->get_parameter_names().size() == ins->inputs().size() - 1); assert(smod->get_parameter_names().size() == ins->inputs().size() - 1);
...@@ -52,14 +52,14 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -52,14 +52,14 @@ struct mlir_compiler : compiler<mlir_compiler>
}}; }};
} }
optional<tuning_config> optional<tuning_config> get_tuning_config(const context& ctx,
get_tuning_config(context&, instruction_ref ins, const operation&, bool exhaustive) const instruction_ref ins,
const operation&,
bool exhaustive) const
{ {
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(*smod, shapes); return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
} }
}; };
......
...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler> ...@@ -81,7 +81,7 @@ struct roialign_compiler : compiler<roialign_compiler>
// coord_trans_mode // coord_trans_mode
auto ctm = v.at("coordinate_transformation_mode").to<std::string>(); auto ctm = v.at("coordinate_transformation_mode").to<std::string>();
float rois_offset = (ctm == "output_half_pixel") ? -0.5f : 0.0f; float rois_offset = (ctm == "half_pixel") ? -0.5f : 0.0f;
options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset); options.params += " -DROIS_OFFSET=" + std::to_string(rois_offset);
// spatial_scale // spatial_scale
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include "migraphx/make_op.hpp" #include "migraphx/make_op.hpp"
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/mlir.hpp> #include <migraphx/gpu/mlir.hpp>
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
...@@ -36,7 +37,10 @@ ...@@ -36,7 +37,10 @@
#include <mutex> #include <mutex>
#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3 #if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
#warning "Incompatible version of rocMLIR library used, disabling" #warning "Incompatible version of rocMLIR library used, disabling"
// Only undefine when not using cppcheck
#ifndef CPPCHECK
#undef MIGRAPHX_MLIR #undef MIGRAPHX_MLIR
#endif
#else #else
#include <mlir-c/RegisterRocMLIR.h> #include <mlir-c/RegisterRocMLIR.h>
#endif #endif
...@@ -50,6 +54,7 @@ ...@@ -50,6 +54,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/gpu/code_object_op.hpp> #include <migraphx/gpu/code_object_op.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/device_name.hpp> #include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/perfdb.hpp> #include <migraphx/gpu/perfdb.hpp>
#include <migraphx/gpu/tuning_config.hpp> #include <migraphx/gpu/tuning_config.hpp>
...@@ -65,6 +70,7 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -65,6 +70,7 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
...@@ -89,6 +95,8 @@ struct mlir_handle ...@@ -89,6 +95,8 @@ struct mlir_handle
friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); } friend bool operator==(ptr x, ptr y) { return x.get_value() == y.get_value(); }
friend bool operator!=(ptr x, ptr y) { return not(x == y); } friend bool operator!=(ptr x, ptr y) { return not(x == y); }
explicit operator bool() const noexcept { return obj != ptr(); }
T obj{}; T obj{};
}; };
...@@ -172,12 +180,6 @@ std::string mlir_print(F f, T x) ...@@ -172,12 +180,6 @@ std::string mlir_print(F f, T x)
return ss.str(); return ss.str();
} }
bool has_xdlops(const std::string& target_arch)
{
const auto device_name = trim(split_string(target_arch, ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
}
struct mlir_program struct mlir_program
{ {
mlir_program() mlir_program()
...@@ -512,7 +514,8 @@ struct mlir_program ...@@ -512,7 +514,8 @@ struct mlir_program
ops.add_attributes({{"function_type", make_function_type(inputs, outputs)}, ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
{"sym_name", sym_name}, {"sym_name", sym_name},
{"kernel", std::string("mixr")}, {"kernel", std::string("mixr")},
{"arch", target_arch}}); {"arch", target_arch},
{"num_cu", num_cu}});
ops.add_region(std::move(region)); ops.add_region(std::move(region));
insert(body, std::move(ops)); insert(body, std::move(ops));
...@@ -559,14 +562,7 @@ struct mlir_program ...@@ -559,14 +562,7 @@ struct mlir_program
static std::string get_symbol_name(const module& m) static std::string get_symbol_name(const module& m)
{ {
for(auto ins : iterator_for(m)) return "mlir_" + gen::generate_name_from_ops(m);
{
if(ins->name() == "convolution" or ins->name() == "dot")
{
return "mlir_" + ins->name();
}
}
return "main";
} }
void parse(const module& m) void parse(const module& m)
...@@ -602,9 +598,6 @@ struct mlir_program ...@@ -602,9 +598,6 @@ struct mlir_program
{ {
pp = pp =
problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()}; problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
// check if HW supports xdlops
if(has_xdlops(target_arch))
ops.add_attributes({{"xdlopsV2", true}});
} }
std::vector<MlirValue> inputs; std::vector<MlirValue> inputs;
...@@ -653,7 +646,12 @@ struct mlir_program ...@@ -653,7 +646,12 @@ struct mlir_program
return op; return op;
} }
void find_target() { target_arch = get_device_name(); } void set_gpu_properties(const context& migraphx_ctx)
{
const auto& device = migraphx_ctx.get_current_device();
target_arch = device.get_device_name();
num_cu = device.get_cu_count();
}
std::pair<std::size_t, std::size_t> get_launch_params() const std::pair<std::size_t, std::size_t> get_launch_params() const
{ {
...@@ -667,7 +665,7 @@ struct mlir_program ...@@ -667,7 +665,7 @@ struct mlir_program
value::binary get_binary() const value::binary get_binary() const
{ {
int size = 0; size_t size = 0;
mlirGetBinary(mmodule.get(), &size, nullptr); mlirGetBinary(mmodule.get(), &size, nullptr);
value::binary result(size); value::binary result(size);
if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data()))) if(mlirGetBinary(mmodule.get(), &size, reinterpret_cast<char*>(result.data())))
...@@ -675,30 +673,44 @@ struct mlir_program ...@@ -675,30 +673,44 @@ struct mlir_program
MIGRAPHX_THROW("Failed to compile mlir program"); MIGRAPHX_THROW("Failed to compile mlir program");
} }
void set_tuning(const value& v) void set_tuning(const value& v) MIGRAPHX_TIDY_CONST
{ {
auto str = v.to<std::string>(); const auto* str = v.if_string();
// We need to make a copy of the buffer since mlirRockTuningSetFromStr may modify the string if(str == nullptr)
std::vector<char> buffer(str.begin(), str.end()); MIGRAPHX_THROW("mlir tuning solutions must be strings");
buffer.push_back(0); if(not mlirRockTuningSetFromStr(mmodule.get(), make_mlir_string_ref(*str)))
if(not mlirRockTuningSetFromStr(mmodule.get(), buffer.data())) MIGRAPHX_THROW("Failed setting tuning key: " + *str);
MIGRAPHX_THROW("Failed setting tuning key: " + str);
} }
tuning_config get_tuning_config() MIGRAPHX_TIDY_CONST tuning_config get_tuning_config(bool exhaustive) MIGRAPHX_TIDY_CONST
{ {
tuning_config tc; tuning_config tc;
run_high_level_pipeline(); run_high_level_pipeline();
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get())}; auto tuning_mode =
for(auto i : range(mlirRockTuningGetNumParamsFull(params.get()))) exhaustive ? RocmlirTuningParamSetKindFull : RocmlirTuningParamSetKindQuick;
if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
tuning_mode = RocmlirTuningParamSetKindExhaustive;
mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
for(auto i : range(mlirRockTuningGetNumParams(params.get())))
{ {
mlir_tuning_param param{mlirRockTuningParamCreate()}; mlir_tuning_param param{mlirRockTuningParamCreate()};
if(not mlirRockTuningParamGet(params.get(), i, param.get())) if(not mlirRockTuningParamGet(params.get(), i, param.get()))
MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i)); MIGRAPHX_THROW("Incorrect mlir tuning parameter: " + std::to_string(i));
tc.solutions.push_back(std::string{mlirRockTuningGetParamStr(param.get())}); std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> perf_key;
size_t perf_key_bytes =
mlirRockTuningParamToString(param.get(), perf_key.data(), perf_key.size());
if(perf_key_bytes > perf_key.size())
MIGRAPHX_THROW("Tuning perf key was " + std::to_string(perf_key_bytes) +
" bytes and thus too long");
tc.solutions.emplace_back(perf_key.begin(), perf_key.begin() + perf_key_bytes);
} }
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()}; std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> tuning_key;
tc.problem = std::string{mlirRockTuningGetKey(tuning_table.get(), mmodule.get())}; size_t tuning_key_bytes =
mlirRockTuningGetKey(mmodule.get(), tuning_key.data(), tuning_key.size());
if(tuning_key_bytes > tuning_key.size())
MIGRAPHX_THROW("Tuning table key was " + std::to_string(tuning_key_bytes) +
" bytes and thus too long");
tc.problem = std::string(tuning_key.begin(), tuning_key.begin() + tuning_key_bytes);
return tc; return tc;
} }
...@@ -706,13 +718,14 @@ struct mlir_program ...@@ -706,13 +718,14 @@ struct mlir_program
// This function appends to tuning cfg file that could be // This function appends to tuning cfg file that could be
// used with rocMLIR tuning scripts. // used with rocMLIR tuning scripts.
void dump_tuning_cfg(const char* prob_config) const void dump_tuning_cfg(const std::string& prob_config) const
{ {
std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{}); std::string tuning_cfg_path = string_value_of(MIGRAPHX_MLIR_TUNING_CFG{});
if(!tuning_cfg_path.empty()) if(not tuning_cfg_path.empty())
{ {
std::vector<std::string> tokens = split_string(prob_config, '\t'); std::vector<std::string> tokens = split_string(prob_config, '\t');
std::string prob = tokens[1]; std::string prob = tokens[2];
if(starts_with(prob, "conv")) if(starts_with(prob, "conv"))
{ {
tuning_cfg_path += ".conv"; tuning_cfg_path += ".conv";
...@@ -722,55 +735,72 @@ struct mlir_program ...@@ -722,55 +735,72 @@ struct mlir_program
tuning_cfg_path += ".gemm"; tuning_cfg_path += ".gemm";
} }
std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app); std::ofstream tuning_cfg(tuning_cfg_path, std::ios::app);
prob =
trim(prob, [](unsigned char c) { return (c == '\0') or (std::isspace(c) != 0); });
tuning_cfg << prob << std::endl; tuning_cfg << prob << std::endl;
} }
} }
static mlir_tuning_table create_tuning_table() static std::pair<mlir_tuning_table, bool> load_tuning_table()
{ {
mlir_tuning_table tuning_table{mlirRockTuningTableCreate()}; mlir_tuning_table tuning_table{mlirRockTuningTableCreate()};
bool found_table = false;
std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{}); std::string tuning_db_path = string_value_of(MIGRAPHX_MLIR_TUNING_DB{});
if(!tuning_db_path.empty()) if(not tuning_db_path.empty())
{ {
std::ifstream tuning_db_tsv(tuning_db_path); std::ifstream tuning_db_tsv(tuning_db_path);
if(tuning_db_tsv) if(tuning_db_tsv)
{ {
found_table = true;
std::string line; std::string line;
while(std::getline(tuning_db_tsv, line)) while(std::getline(tuning_db_tsv, line))
{ {
std::vector<std::string> tokens = split_string(line, '\t'); std::vector<std::string> tokens = split_string(line, '\t');
std::string arch = tokens[0]; std::string arch = tokens[0];
std::string prob = tokens[1]; std::string num_cu = tokens[1];
std::string perf = tokens[2]; std::string prob = tokens[2];
std::string key = arch.append("\t").append(prob); std::string perf = tokens[3];
mlirRockTuningUpdateTable(tuning_table.get(), key.c_str(), perf.c_str(), 1.0); std::string key = arch.append("\t").append(num_cu).append("\t").append(prob);
mlirRockTuningUpdateTable(tuning_table.get(),
make_mlir_string_ref(key),
make_mlir_string_ref(perf),
1.0);
} }
} }
} }
else else
{ {
found_table = false;
std::cerr std::cerr
<< "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for " << "WARNING: MLIR tuning db not found. Please set MIGRAPHX_MLIR_TUNING_DB for "
"optimal performance." "optimal performance."
<< std::endl; << std::endl;
} }
return tuning_table; return std::make_pair(std::move(tuning_table), found_table);
} }
bool get_module_tuned() const bool get_module_tuned() const
{ {
static mlir_tuning_table tuning_table = create_tuning_table(); static std::pair<mlir_tuning_table, bool> tuning_table = load_tuning_table();
// The tuning table as currently implemented is currently not if(not mlirRockTuningSetFromTable(tuning_table.first.get(), mmodule.get()))
// thread safe. This will be fixed in the future. For now,
// stick a mutex around all tuning table interaction.
static std::mutex lock;
std::lock_guard<std::mutex> guard(lock);
if(!mlirRockTuningSetFromTable(tuning_table.get(), mmodule.get()))
{ {
const char* prob_config = mlirRockTuningGetKey(tuning_table.get(), mmodule.get()); std::array<char, ROCMLIR_TUNING_KEY_BUFSZ> prob_config;
std::stringstream key(prob_config); size_t prob_config_bytes =
std::cerr << "fails to set param on" << prob_config << std::endl; mlirRockTuningGetKey(mmodule.get(), prob_config.data(), prob_config.size());
dump_tuning_cfg(prob_config); if(prob_config_bytes >= prob_config.size())
{
std::cerr << "MLIR tuning key overflowed buffer, needed " << prob_config_bytes
<< " bytes" << std::endl;
return false;
}
std::string prob_config_str(prob_config.begin(),
prob_config.begin() + prob_config_bytes);
if(tuning_table.second)
{
std::cerr << "NOTE: MLIR tuning table did not include a key for " << prob_config_str
<< std::endl;
}
dump_tuning_cfg(prob_config_str);
return false; return false;
} }
return true; return true;
...@@ -781,7 +811,8 @@ struct mlir_program ...@@ -781,7 +811,8 @@ struct mlir_program
mlir_module mmodule; mlir_module mmodule;
problem_params pp; problem_params pp;
std::deque<std::string> strings{}; std::deque<std::string> strings{};
std::string target_arch; std::string target_arch = "";
std::size_t num_cu = 0;
std::string sym_name; std::string sym_name;
}; };
...@@ -838,7 +869,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs) ...@@ -838,7 +869,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
} }
} }
code_object_op compile_mlir(const context&, code_object_op compile_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<instruction_ref>& inputs, const std::vector<instruction_ref>& inputs,
const value& solution) const value& solution)
...@@ -846,15 +877,22 @@ code_object_op compile_mlir(const context&, ...@@ -846,15 +877,22 @@ code_object_op compile_mlir(const context&,
adjust_param_shapes(m, to_shapes(inputs)); adjust_param_shapes(m, to_shapes(inputs));
const bool trace = enabled(MIGRAPHX_TRACE_MLIR{}); const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
static std::mutex mutex;
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << m << std::endl; std::cout << m << std::endl;
}
mlir_program mp; mlir_program mp;
mp.find_target(); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
auto mod_op = mlirModuleGetOperation(mp.mmodule.get()); auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
if(trace) if(trace)
{
const std::lock_guard<std::mutex> lock(mutex);
std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl; std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
}
auto co = mp.compile(solution); auto co = mp.compile(solution);
co.expected_inputs = to_shapes(inputs); co.expected_inputs = to_shapes(inputs);
co.output = m.get_output_shapes().front(); co.output = m.get_output_shapes().front();
...@@ -877,14 +915,17 @@ instruction_ref insert_mlir(module& m, ...@@ -877,14 +915,17 @@ instruction_ref insert_mlir(module& m,
return m.insert_instruction(ins, co, refs); return m.insert_instruction(ins, co, refs);
} }
tuning_config get_tuning_config_mlir(module m, const std::vector<shape>& inputs) tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m,
const std::vector<shape>& inputs,
bool exhaustive)
{ {
adjust_param_shapes(m, inputs); adjust_param_shapes(m, inputs);
mlir_program mp; mlir_program mp;
mp.find_target(); mp.set_gpu_properties(migraphx_ctx);
mp.parse(m); mp.parse(m);
return mp.get_tuning_config(); return mp.get_tuning_config(exhaustive);
} }
#else #else
...@@ -909,10 +950,14 @@ instruction_ref ...@@ -909,10 +950,14 @@ instruction_ref
insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&) insert_mlir(module& m, instruction_ref, code_object_op co, const std::vector<instruction_ref>&)
{ {
use(co); use(co);
use(m);
return m.end(); return m.end();
} }
tuning_config get_tuning_config_mlir(module, const std::vector<shape>&) { return {}; } tuning_config get_tuning_config_mlir(const context&, module, const std::vector<shape>&, bool)
{
return {};
}
// NOLINTEND(performance-unnecessary-value-param) // NOLINTEND(performance-unnecessary-value-param)
#endif #endif
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifdef __HIP_DEVICE_COMPILE__
#error \
"Device compilation not allowed for migraphx_gpu. Do not link with hip::device. Device code should go into migraphx_device or migraphx_kernels"
#endif
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
...@@ -45,40 +46,42 @@ struct layernorm_base ...@@ -45,40 +46,42 @@ struct layernorm_base
} }
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = N;
if(not mods.empty()) if(not mods.empty())
{ {
auto* pm = mods.front(); auto* pm = mods.front();
nargs = pm->get_parameter_names().size(); nargs += pm->get_parameter_names().size() - 1;
} }
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N); check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs);
auto s = inputs.at(0); auto s = inputs.front();
auto t = s.type(); auto t = s.type();
if(not mods.empty()) if(not mods.empty())
t = mods.front()->get_output_shapes().front().type(); t = mods.front()->get_output_shapes().front().type();
if(s.scalar())
{ // Scalar output if all inputs are scalar
return s; if(inputs.front().elements() == 1 and
} all_of(inputs, [](const auto& ss) { return ss.scalar(); }))
else if(s.broadcasted()) return inputs.front();
{ auto l_s = shape::from_permutation(
return {t, s.lens()}; t, s.lens(), find_permutation(std::vector<shape>(inputs.begin(), inputs.begin() + N)));
} // just prelayernorm or preadd_layernorm
else if(nargs <= N)
{ return l_s;
return s.with_lens(t, s.lens()); // else, layernorm + pointwise fusion, preserve layout of fused op
} std::vector<shape> lp_s(inputs.begin() + N, inputs.end());
lp_s.insert(lp_s.begin(), l_s);
return shape::from_permutation(t, s.lens(), find_permutation(lp_s));
} }
}; };
struct layernorm : layernorm_base<layernorm, 0> struct layernorm : layernorm_base<layernorm, 1>
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
}; };
MIGRAPHX_REGISTER_OP(layernorm); MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1> struct add_layernorm : layernorm_base<add_layernorm, 2>
{ {
std::string name() const { return "gpu::preadd_layernorm"; } std::string name() const { return "gpu::preadd_layernorm"; }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment