Commit acad34c6 authored by charlie's avatar charlie
Browse files

Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into refactor_dynamic_compute

parents 65e14286 70e63960
...@@ -33,7 +33,7 @@ def rocmtestnode(Map conf) { ...@@ -33,7 +33,7 @@ def rocmtestnode(Map conf) {
} }
} }
node(name) { node(name) {
withEnv(['HSA_ENABLE_SDMA=0', 'MIOPEN_DEBUG_GCN_ASM_KERNELS=0']) { withEnv(['HSA_ENABLE_SDMA=0']) {
stage("checkout ${variant}") { stage("checkout ${variant}") {
checkout scm checkout scm
} }
......
...@@ -40,7 +40,6 @@ struct fmod : binary<fmod> ...@@ -40,7 +40,6 @@ struct fmod : binary<fmod>
a["commutative"] = false; a["commutative"] = false;
return a; return a;
} }
std::string point_function() const { return "fmod"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return std::fmod(x, y); }; return [](auto x, auto y) { return std::fmod(x, y); };
......
...@@ -38,9 +38,9 @@ struct mod : binary<mod> ...@@ -38,9 +38,9 @@ struct mod : binary<mod>
{ {
auto a = base_attributes(); auto a = base_attributes();
a["commutative"] = false; a["commutative"] = false;
a["point_op"] = "${function:fmod}((${function:remainder}(${0}, ${1})) + ${1}, ${1})";
return a; return a;
} }
std::string point_function() const { return "mod"; }
auto apply() const auto apply() const
{ {
return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); }; return [](auto x, auto y) { return std::fmod((std::remainder(x, y)) + y, y); };
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include <migraphx/onnx/op_parser.hpp> #include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/batch_norm_inference.hpp> #include <migraphx/instruction.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -36,28 +36,63 @@ struct parse_batchnorm : op_parser<parse_batchnorm> ...@@ -36,28 +36,63 @@ struct parse_batchnorm : op_parser<parse_batchnorm>
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser, const onnx_parser& parser,
onnx_parser::node_info info, const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const std::vector<instruction_ref> args) const
{ {
float epsilon = 1e-5f; float epsilon = 1e-5f;
float momentum = 0.9f;
op::batch_norm_inference::bn_infer_mode_t bn_mode = op::batch_norm_inference::spatial;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>(); epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
} }
if(contains(info.attributes, "momentum")) auto x_lens = args[0]->get_shape().lens();
auto x_type = args[0]->get_shape().type();
if(std::any_of(args.cbegin() + 1, args.cend(), [](auto a) {
return a->get_shape().lens().size() != 1;
}))
{
MIGRAPHX_THROW("PARSE_BATCHNORM: argument scale, bias, mean, or var rank != 1");
}
if(x_lens.size() == 1)
{
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto n0 = info.add_broadcastable_binary_op("sub", args[0], args[3]);
auto d0 = info.add_broadcastable_binary_op("add", args[4], eps);
auto d1 = info.add_broadcastable_binary_op("pow", d0, rt);
auto div0 = info.add_broadcastable_binary_op("div", n0, d1);
auto r0 = info.add_broadcastable_binary_op("mul", div0, args[1]);
return info.add_broadcastable_binary_op("add", r0, args[2]);
}
else if(x_lens.size() > 2)
{ {
momentum = parser.parse_value(info.attributes.at("momentum")).at<float>(); // unsqueeze tensors of shape (C) to broadcast correctly
std::vector<int64_t> unsqueeze_axes(x_lens.size() - 2);
std::iota(unsqueeze_axes.begin(), unsqueeze_axes.end(), 1);
auto rt = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0.5}});
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {epsilon}});
auto scale_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[1]);
auto bias_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[2]);
auto mean_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[3]);
auto var_unsqueeze = info.add_instruction(
migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), args[4]);
auto numer = info.add_broadcastable_binary_op("sub", args[0], mean_unsqueeze);
auto var_eps = info.add_broadcastable_binary_op("add", var_unsqueeze, eps);
auto denom = info.add_broadcastable_binary_op("pow", var_eps, rt);
auto div0 = info.add_broadcastable_binary_op("div", numer, denom);
auto r0 = info.add_broadcastable_binary_op("mul", div0, scale_unsqueeze);
return info.add_broadcastable_binary_op("add", r0, bias_unsqueeze);
} }
if(contains(info.attributes, "spatial")) else
{ {
bn_mode = (parser.parse_value(info.attributes.at("spatial")).at<uint64_t>() > 0) // num dims either 0 or 2
? op::batch_norm_inference::spatial MIGRAPHX_THROW("PARSE_BATCHNORM: rank " + std::to_string(x_lens.size()) +
: op::batch_norm_inference::per_activation; " input tensor, unhandled data format");
} }
op::batch_norm_inference op{epsilon, momentum, bn_mode};
return info.add_instruction(op, args);
} }
}; };
......
...@@ -35,6 +35,7 @@ add_library(migraphx_cpu ...@@ -35,6 +35,7 @@ add_library(migraphx_cpu
dnnl.cpp dnnl.cpp
eltwise.cpp eltwise.cpp
erf.cpp erf.cpp
fmod.cpp
fuse_ops.cpp fuse_ops.cpp
gather.cpp gather.cpp
gemm.cpp gemm.cpp
...@@ -42,6 +43,7 @@ add_library(migraphx_cpu ...@@ -42,6 +43,7 @@ add_library(migraphx_cpu
logsoftmax.cpp logsoftmax.cpp
lowering.cpp lowering.cpp
lrn.cpp lrn.cpp
mod.cpp
preallocate.cpp preallocate.cpp
pooling.cpp pooling.cpp
reduction.cpp reduction.cpp
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/fmod.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template struct cpu_binary<op::fmod>;
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -43,6 +43,8 @@ ...@@ -43,6 +43,8 @@
#include <migraphx/op/argmax.hpp> #include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp> #include <migraphx/op/argmin.hpp>
#include <migraphx/op/rnn_var_sl_last_output.hpp> #include <migraphx/op/rnn_var_sl_last_output.hpp>
#include <migraphx/op/mod.hpp>
#include <migraphx/op/fmod.hpp>
#include <migraphx/shape_for_each.hpp> #include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp> #include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp> #include <migraphx/par_dfor.hpp>
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/config.hpp>
#include <migraphx/cpu/pointwise.hpp>
#include <migraphx/op/mod.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace cpu {
template struct cpu_binary<op::mod>;
} // namespace cpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
...@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs) ...@@ -48,12 +49,13 @@ static std::vector<std::size_t> vector_sizes(const std::vector<shape>& inputs)
return {4, 2}; return {4, 2};
} }
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs) vectorize vectorize::elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes)
{ {
if(std::all_of( if(std::all_of(
inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; })) inputs.begin(), inputs.end(), [&](const auto& s) { return s.lens()[axis] == 1; }))
return {1, axis}; return {1, axis};
auto sizes = vector_sizes(inputs);
std::vector<std::size_t> max_vec_size; std::vector<std::size_t> max_vec_size;
std::transform(inputs.begin(), std::transform(inputs.begin(),
inputs.end(), inputs.end(),
...@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs ...@@ -81,6 +83,33 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis}; return {*std::min_element(max_vec_size.begin(), max_vec_size.end()), axis};
} }
vectorize vectorize::elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs)
{
if(inputs.empty())
return {1, axis};
std::size_t n = std::max_element(inputs.begin(),
inputs.end(),
by(std::less<>{}, [](const auto& s) { return s.elements(); }))
->elements();
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
std::size_t over = n / max_global;
bool broadcasted =
std::any_of(inputs.begin(), inputs.end(), [](const auto& s) { return s.broadcasted(); });
std::vector<std::size_t> sizes;
if(broadcasted and over > 8)
sizes.push_back(8);
if(over > 4)
sizes.push_back(4);
sizes.push_back(2);
return elements(axis, inputs, sizes);
}
vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs)
{
return elements(axis, inputs, vector_sizes(inputs));
}
std::string vectorize::str() const std::string vectorize::str() const
{ {
return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()"; return "vectorize<" + to_string(size) + ", " + to_string(axis) + ">()";
......
...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -36,6 +36,9 @@ inline namespace MIGRAPHX_INLINE_NS {
struct shape; struct shape;
namespace gpu { namespace gpu {
struct context;
namespace gen { namespace gen {
struct vectorize struct vectorize
...@@ -43,6 +46,10 @@ struct vectorize ...@@ -43,6 +46,10 @@ struct vectorize
std::size_t size = 1; std::size_t size = 1;
std::size_t axis = 0; std::size_t axis = 0;
static vectorize elements(std::size_t axis, const std::vector<shape>& inputs); static vectorize elements(std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(context& ctx, std::size_t axis, const std::vector<shape>& inputs);
static vectorize elements(std::size_t axis,
const std::vector<shape>& inputs,
const std::vector<std::size_t>& sizes);
std::string str() const; std::string str() const;
}; };
struct preload struct preload
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
*/ */
#ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP #ifndef MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP #define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraphx/manage_ptr.hpp> #include <migraphx/manage_ptr.hpp>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <rocblas.h> #include <rocblas.h>
...@@ -37,6 +36,11 @@ using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_h ...@@ -37,6 +36,11 @@ using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_h
rocblas_handle_ptr create_rocblas_handle_ptr(); rocblas_handle_ptr create_rocblas_handle_ptr();
rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s); rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s);
struct context;
bool get_compute_fp32_flag();
bool get_int8_x4_format(context& ctx);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
...@@ -74,7 +74,7 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -74,7 +74,7 @@ struct concat_compiler : compiler<concat_compiler>
options.output = inputs.back(); options.output = inputs.back();
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.inputs); auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(axis, options.inputs); auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
......
...@@ -50,7 +50,6 @@ ${preamble} ...@@ -50,7 +50,6 @@ ${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, ${eps}, xs...); ${layernorm}<${axis}>(${post}, ${eps}, xs...);
}); });
...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -78,9 +77,8 @@ struct layernorm_compiler : compiler<layernorm_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(axis == faxis) if(axis == faxis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256); auto block_size = compute_block_size(relements, 256);
...@@ -96,7 +94,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -96,7 +94,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"post", v.get("post", std::string{"op::id{}"})}, {"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})}, {"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})}, {"layernorm", v.get("layernorm", std::string{"layernorm"})},
......
...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -75,20 +75,16 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options.virtual_inputs = reduce_dims(inputs); options.virtual_inputs = reduce_dims(inputs);
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
auto axis = find_fast_axis(options.virtual_inputs); auto axis = find_fast_axis(options.virtual_inputs);
auto vec = vectorize::elements(axis, options.virtual_inputs); auto vec = vectorize::elements(ctx, axis, options.virtual_inputs);
auto preloads = preload::broadcasts(axis, options.virtual_inputs);
options.kernel_name = v.get("kernel", "kernel"); options.kernel_name = v.get("kernel", "kernel");
options.set_launch_params( options.set_launch_params(
v, v, compute_global_for(ctx, options.output.elements() / vec.size, 256));
compute_global_for(ctx,
options.output.elements() / vec.size,
oversubscribe_if(not preloads.is_preloading())));
auto src = interpolate_string(pointwise_kernel, auto src = interpolate_string(pointwise_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"lambda", v.at("lambda").to<std::string>()}, {"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(preloads, vec)}, {"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}}); {"preamble", v.get("preamble", std::string{})}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
......
...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler> ...@@ -121,7 +121,7 @@ struct reduce_compiler : compiler<reduce_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(options.virtual_inputs.back().lens()[faxis] == 1) if(options.virtual_inputs.back().lens()[faxis] == 1)
{ {
vec = vectorize::elements(faxis, options.virtual_inputs); vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
} }
auto relements = get_reduce_elements(options.virtual_inputs) / vec.size; auto relements = get_reduce_elements(options.virtual_inputs) / vec.size;
auto nelements = options.virtual_inputs.back().elements(); auto nelements = options.virtual_inputs.back().elements();
......
...@@ -69,7 +69,7 @@ struct softmax_compiler : compiler<softmax_compiler> ...@@ -69,7 +69,7 @@ struct softmax_compiler : compiler<softmax_compiler>
// Vectorize if the axis is a reduction axis // Vectorize if the axis is a reduction axis
if(faxis == axis) if(faxis == axis)
{ {
vec = vectorize::elements(faxis, inputs); vec = vectorize::elements(ctx, faxis, inputs);
} }
auto relements = inputs[0].lens()[axis] / vec.size; auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]); auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
......
...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor) ...@@ -104,6 +104,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH(isnan, ::isnan) MIGRAPHX_DEVICE_MATH(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH(log, ::log) MIGRAPHX_DEVICE_MATH(log, ::log)
MIGRAPHX_DEVICE_MATH(pow, ::pow) MIGRAPHX_DEVICE_MATH(pow, ::pow)
MIGRAPHX_DEVICE_MATH(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH(round, ::round) MIGRAPHX_DEVICE_MATH(round, ::round)
MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt) MIGRAPHX_DEVICE_MATH(rsqrt, ::rsqrt)
MIGRAPHX_DEVICE_MATH(sin, ::sin) MIGRAPHX_DEVICE_MATH(sin, ::sin)
...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh) ...@@ -111,6 +112,7 @@ MIGRAPHX_DEVICE_MATH(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt) MIGRAPHX_DEVICE_MATH(sqrt, ::sqrt)
MIGRAPHX_DEVICE_MATH(tan, ::tan) MIGRAPHX_DEVICE_MATH(tan, ::tan)
MIGRAPHX_DEVICE_MATH(tanh, ::tanh) MIGRAPHX_DEVICE_MATH(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH(fmod, ::fmod)
// Float overloads // Float overloads
MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf) MIGRAPHX_DEVICE_MATH_FOR(float, acos, ::acosf)
...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf) ...@@ -126,6 +128,7 @@ MIGRAPHX_DEVICE_MATH_FOR(float, sin, ::sinf)
MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf) MIGRAPHX_DEVICE_MATH_FOR(float, sinh, ::sinhf)
MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf) MIGRAPHX_DEVICE_MATH_FOR(float, tan, ::tanf)
MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf) MIGRAPHX_DEVICE_MATH_FOR(float, tanh, ::tanhf)
MIGRAPHX_DEVICE_MATH_FOR(float, fmod, ::fmodf)
// Builtin half functions // Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs) MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, abs, ::__habs)
...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf) ...@@ -148,11 +151,13 @@ MIGRAPHX_DEVICE_MATH_HALF(erf, ::erf)
MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor) MIGRAPHX_DEVICE_MATH_HALF(floor, ::floor)
MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan) MIGRAPHX_DEVICE_MATH_HALF(isnan, ::isnan)
MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow) MIGRAPHX_DEVICE_MATH_HALF(pow, ::pow)
MIGRAPHX_DEVICE_MATH_HALF(remainder, ::remainder)
MIGRAPHX_DEVICE_MATH_HALF(round, ::round) MIGRAPHX_DEVICE_MATH_HALF(round, ::round)
MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin) MIGRAPHX_DEVICE_MATH_HALF(sin, ::sin)
MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh) MIGRAPHX_DEVICE_MATH_HALF(sinh, ::sinh)
MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan) MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh) MIGRAPHX_DEVICE_MATH_HALF(tanh, ::tanh)
MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions // Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats // The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh) ...@@ -226,11 +231,13 @@ MIGRAPHX_DEVICE_MATH_VEC(cosh)
MIGRAPHX_DEVICE_MATH_VEC(erf) MIGRAPHX_DEVICE_MATH_VEC(erf)
MIGRAPHX_DEVICE_MATH_VEC(exp) MIGRAPHX_DEVICE_MATH_VEC(exp)
MIGRAPHX_DEVICE_MATH_VEC(floor) MIGRAPHX_DEVICE_MATH_VEC(floor)
MIGRAPHX_DEVICE_MATH_VEC(fmod)
MIGRAPHX_DEVICE_MATH_VEC(isnan) MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC(log) MIGRAPHX_DEVICE_MATH_VEC(log)
MIGRAPHX_DEVICE_MATH_VEC(max) MIGRAPHX_DEVICE_MATH_VEC(max)
MIGRAPHX_DEVICE_MATH_VEC(min) MIGRAPHX_DEVICE_MATH_VEC(min)
MIGRAPHX_DEVICE_MATH_VEC(pow) MIGRAPHX_DEVICE_MATH_VEC(pow)
MIGRAPHX_DEVICE_MATH_VEC(remainder)
MIGRAPHX_DEVICE_MATH_VEC(round) MIGRAPHX_DEVICE_MATH_VEC(round)
MIGRAPHX_DEVICE_MATH_VEC(rsqrt) MIGRAPHX_DEVICE_MATH_VEC(rsqrt)
MIGRAPHX_DEVICE_MATH_VEC(sin) MIGRAPHX_DEVICE_MATH_VEC(sin)
......
...@@ -81,26 +81,14 @@ struct miopen_apply ...@@ -81,26 +81,14 @@ struct miopen_apply
(void)i; (void)i;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
void init() void init()
{ {
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38 auto& ctx = get_context();
auto& ctx = get_context(); int8_x4_format = get_int8_x4_format(ctx);
const auto device_name = trim(split_string(get_device_name(), ':').front()); compute_fp32 = get_compute_fp32_flag();
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
offload_copy = (mod->name() == "main") ? pass->offload_copy : false; offload_copy = (mod->name() == "main") ? pass->offload_copy : false;
......
...@@ -21,7 +21,13 @@ ...@@ -21,7 +21,13 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE. * THE SOFTWARE.
*/ */
#include <unordered_set>
#include <migraphx/ranges.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/rocblas.hpp> #include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s) ...@@ -41,6 +47,33 @@ rocblas_handle_ptr create_rocblas_handle_ptr(hipStream_t s)
return rb; return rb;
} }
const std::unordered_set<std::string>& get_rocblas_fp32_archs()
{
static std::unordered_set<std::string> supported_archs{"gfx908", "gfx90a"};
return supported_archs;
}
bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
}
bool get_int8_x4_format(context& ctx)
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
}
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -307,12 +307,14 @@ TEST_CASE(compile_math) ...@@ -307,12 +307,14 @@ TEST_CASE(compile_math)
"erf(x)", "erf(x)",
"exp(x)", "exp(x)",
"floor(x)", "floor(x)",
"fmod(x, x)",
"isnan(x)", "isnan(x)",
"log(x)", "log(x)",
"max(x, x)", "max(x, x)",
"min(x, x)", "min(x, x)",
"pow(x, 0)", "pow(x, 0)",
"pow(x, x)", "pow(x, x)",
"remainder(x,x)",
"round(x)", "round(x)",
"rsqrt(x)", "rsqrt(x)",
"sin(x)", "sin(x)",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment