Commit b2f12dae authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/develop' into ck-gemm-int8

parents 05122cc7 ce2adafd
...@@ -8,6 +8,9 @@ on: ...@@ -8,6 +8,9 @@ on:
- master - master
- 'release/**' - 'release/**'
env:
DOCKER_USER: ${{secrets.DOCKERHUB_USERID}}
DOCKER_TOKEN: ${{secrets.DOCKERHUB_TOKEN}}
jobs: jobs:
cancel: cancel:
...@@ -17,23 +20,93 @@ jobs: ...@@ -17,23 +20,93 @@ jobs:
uses: styfle/cancel-workflow-action@0.11.0 uses: styfle/cancel-workflow-action@0.11.0
with: with:
access_token: ${{ github.token }} access_token: ${{ github.token }}
check_image:
name: Check if image exists in registry
runs-on: ubuntu-latest
outputs:
imageexists: ${{ steps.check_image.outputs.imageexists }}
imagetag: ${{ steps.image_hash.outputs.imagetag }}
imageexists_sles: ${{ steps.check_image.outputs.imageexists_sles }}
imagetag_sles: ${{ steps.image_hash.outputs.imagetag_sles }}
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Create Image Tag
id: image_hash
run: |
echo "imagetag=rocm/migraphx-private:hip-clang-${{hashFiles('**/hip-clang.docker', '**/*requirements.txt', '**/install_prereqs.sh', '**/rbuild.ini')}}" >> $GITHUB_OUTPUT
echo "imagetag_sles=rocm/migraphx-sles-private:hip-clang-${{hashFiles('**/tools/docker/sles.docker', '**/*requirements.txt', '**/install_prereqs.sh', '**/rbuild.ini')}}" >> $GITHUB_OUTPUT
- name: Check if image is built already
id: check_image
env:
DOCKERIMAGE: ${{ steps.image_hash.outputs.imagetag }}
DOCKERIMAGE_SLES: ${{ steps.image_hash.outputs.imagetag_sles }}
run: |
echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
if [[ "$(docker manifest inspect $DOCKERIMAGE 2> /dev/null)" != "" ]]; then
echo "imageexists=true" >> $GITHUB_OUTPUT
echo "Image already exists, skip building available"
else
echo "imageexists=false" >> $GITHUB_OUTPUT
echo "Tag does not exist, build and publishing required"
fi
if [[ "$(docker manifest inspect $DOCKERIMAGE_SLES 2> /dev/null)" != "" ]]; then
echo "imageexists_sles=true" >> $GITHUB_OUTPUT
echo "SLES Image already exists, skip building available"
else
echo "imageexists_sles=false" >> $GITHUB_OUTPUT
echo "SLES Tag does not exist, build and publishing required"
fi
build_image:
name: Build image
runs-on: ROCM-Ubuntu
needs: check_image
if: ${{ needs.check_image.outputs.imageexists != 'true' }}
steps:
- uses: actions/checkout@v3
- name: Build and publish
env:
DOCKERIMAGE: ${{ needs.check_image.outputs.imagetag }}
run: |
echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
docker build . --file hip-clang.docker --tag $DOCKERIMAGE;
docker push $DOCKERIMAGE;
build_SLES_image:
name: Build SLES image
runs-on: ROCM-Ubuntu
needs: check_image
if: ${{ needs.check_image.outputs.imageexists_sles != 'true' }}
steps:
- uses: actions/checkout@v3
- name: Build and publish SLES
env:
DOCKERIMAGE_SLES: ${{ needs.check_image.outputs.imagetag_sles }}
run: |
echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
docker build . --file tools/docker/sles.docker --tag $DOCKERIMAGE_SLES;
docker push $DOCKERIMAGE_SLES;
tidy: tidy:
runs-on: ROCM-Ubuntu runs-on: ROCM-Ubuntu
needs: [ build_image, check_image ]
env:
DOCKERIMAGE: ${{ needs.check_image.outputs.imagetag }}
if: ${{ !cancelled() && (needs.build_image.result == 'success' || needs.build_image.result == 'skipped') }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- name: Docker layer cache
uses: jpribyl/action-docker-layer-caching@v0.1.1
with:
key: docker-layer-caching-migraphx-${{hashFiles('hip-clang.docker', '**/*requirements.txt', '**/install_prereqs.sh', 'rbuild.ini')}}
restore-keys:
docker-layer-caching-migraphx-
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
- name: Restore cache files for tidy - name: Restore cache files for tidy
uses: actions/cache/restore@v3 uses: actions/cache/restore@v3
id: tidy_restore id: tidy_restore
...@@ -41,13 +114,13 @@ jobs: ...@@ -41,13 +114,13 @@ jobs:
path: tidy-cache path: tidy-cache
key: tidy-cache-${{ github.ref }} key: tidy-cache-${{ github.ref }}
restore-keys: tidy-cache- restore-keys: tidy-cache-
- name: Build the Docker image - name: Docker Login
run: | run: |
docker build . --file hip-clang.docker --tag migraphx echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
- name: Clang tidy - name: Clang Tidy
shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data $DOCKERIMAGE bash < {0}"
run: | run: |
mkdir build mkdir build
cd build cd build
...@@ -84,21 +157,14 @@ jobs: ...@@ -84,21 +157,14 @@ jobs:
cppcheck: cppcheck:
runs-on: ROCM-Ubuntu runs-on: ROCM-Ubuntu
needs: [ build_image, check_image ]
env:
DOCKERIMAGE: ${{ needs.check_image.outputs.imagetag }}
if: ${{ !cancelled() && (needs.build_image.result == 'success' || needs.build_image.result == 'skipped') }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- name: Docker layer cache
uses: jpribyl/action-docker-layer-caching@v0.1.1
with:
key: docker-layer-caching-migraphx-${{hashFiles('hip-clang.docker', '**/*requirements.txt', '**/install_prereqs.sh', 'rbuild.ini')}}
restore-keys:
docker-layer-caching-migraphx-
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
- name: Restore cache files for cppcheck - name: Restore cache files for cppcheck
id: cppcheck_restore id: cppcheck_restore
uses: actions/cache/restore@v3 uses: actions/cache/restore@v3
...@@ -107,11 +173,12 @@ jobs: ...@@ -107,11 +173,12 @@ jobs:
key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ github.ref }} key: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-${{ github.ref }}
restore-keys: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}- restore-keys: cppcheck-cache-${{ hashFiles('cppcheck.rules', 'CMakeLists.txt') }}-
- name: Build the Docker image - name: Docker Login
run: docker build . --file hip-clang.docker --tag migraphx run: |
echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
- name: Cppcheck - name: Cppcheck
shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data $DOCKERIMAGE bash < {0}"
run: | run: |
mkdir build mkdir build
cd build cd build
...@@ -142,29 +209,23 @@ jobs: ...@@ -142,29 +209,23 @@ jobs:
format: format:
runs-on: ROCM-Ubuntu runs-on: ubuntu-latest
needs: [ build_image, check_image ]
env:
DOCKERIMAGE: ${{ needs.check_image.outputs.imagetag }}
if: ${{ !cancelled() && (needs.build_image.result == 'success' || needs.build_image.result == 'skipped') }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
# In this step, this action saves a list of existing images, - name: Docker Login
# the cache is created without them in the post run. run: |
# It also restores the cache if it exists. echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
- name: Docker layer cache
uses: jpribyl/action-docker-layer-caching@v0.1.1
with:
key: docker-layer-caching-migraphx-${{hashFiles('hip-clang.docker', '**/*requirements.txt', '**/install_prereqs.sh', 'rbuild.ini')}}
restore-keys:
docker-layer-caching-migraphx-
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
- name: Build the Docker image
run: docker build . --file hip-clang.docker --tag migraphx
- name: Check formatting - name: Check formatting
shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx bash < {0}" shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data $DOCKERIMAGE bash < {0}"
run: | run: |
set -e set -e
git config --global --add safe.directory /data git config --global --add safe.directory /data
...@@ -172,26 +233,16 @@ jobs: ...@@ -172,26 +233,16 @@ jobs:
sles: sles:
runs-on: ROCM-Ubuntu runs-on: ROCM-Ubuntu
needs: [ build_SLES_image, check_image ]
env:
DOCKERIMAGE_SLES: ${{ needs.check_image.outputs.imagetag_sles }}
if: ${{ !cancelled() && (needs.build_SLES_image.result == 'success' || needs.build_SLES_image.result == 'skipped') }}
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
# In this step, this action saves a list of existing images,
# the cache is created without them in the post run.
# It also restores the cache if it exists.
- name: Docker layer cache
uses: jpribyl/action-docker-layer-caching@v0.1.1
with:
key: docker-layer-caching-migraphx-sles-${{hashFiles('hip-clang.docker', '**/*requirements.txt', '**/install_prereqs.sh', 'rbuild.ini')}}
restore-keys:
docker-layer-caching-migraphx-sles-
# Ignore the failure of a step and avoid terminating the job.
continue-on-error: true
- name: Build the Docker image
run: docker build . --file tools/docker/sles.docker --tag migraphx-sles
- name: Restore cache files for ccache - name: Restore cache files for ccache
uses: actions/cache/restore@v3 uses: actions/cache/restore@v3
id: ccache_restore id: ccache_restore
...@@ -200,8 +251,12 @@ jobs: ...@@ -200,8 +251,12 @@ jobs:
key: ccache-sles-${{ github.ref }} key: ccache-sles-${{ github.ref }}
restore-keys: ccache-sles- restore-keys: ccache-sles-
- name: Docker Login
run: |
echo $DOCKER_TOKEN | docker login -u $DOCKER_USER --password-stdin
- name: Build migraphx - name: Build migraphx
shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data migraphx-sles bash < {0}" shell: bash -c "docker run -i -v=$GITHUB_WORKSPACE:/data -w /data $DOCKERIMAGE_SLES bash < {0}"
run: | run: |
set -e set -e
export CCACHE_COMPRESSLEVEL=10 export CCACHE_COMPRESSLEVEL=10
......
...@@ -86,7 +86,7 @@ function(py_add_module NAME) ...@@ -86,7 +86,7 @@ function(py_add_module NAME)
) )
endfunction() endfunction()
set(PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9 3.10) set(PYTHON_SEARCH_VERSIONS 3.5 3.6 3.7 3.8 3.9 3.10)
set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "") set(PYTHON_DISABLE_VERSIONS "" CACHE STRING "")
foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS}) foreach(PYTHON_DISABLE_VERSION ${PYTHON_DISABLE_VERSIONS})
list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION}) list(REMOVE_ITEM PYTHON_SEARCH_VERSIONS ${PYTHON_DISABLE_VERSION})
......
...@@ -142,6 +142,7 @@ register_migraphx_ops( ...@@ -142,6 +142,7 @@ register_migraphx_ops(
equal equal
erf erf
exp exp
fill
flatten flatten
floor floor
fmod fmod
......
...@@ -153,7 +153,7 @@ struct check_shapes ...@@ -153,7 +153,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() != n) if(begin->ndim() != n)
MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported"); MIGRAPHX_THROW(prefix() + "Only " + std::to_string(n) + "d supported");
} }
return *this; return *this;
...@@ -168,7 +168,7 @@ struct check_shapes ...@@ -168,7 +168,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() > n) if(begin->ndim() > n)
MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at most " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -184,7 +184,7 @@ struct check_shapes ...@@ -184,7 +184,7 @@ struct check_shapes
{ {
if(begin != end) if(begin != end)
{ {
if(begin->max_lens().size() < n) if(begin->ndim() < n)
MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) + MIGRAPHX_THROW(prefix() + "Shape must have at least " + std::to_string(n) +
" dimensions"); " dimensions");
} }
...@@ -254,6 +254,16 @@ struct check_shapes ...@@ -254,6 +254,16 @@ struct check_shapes
return *this; return *this;
} }
/*!
* Check all shapes are scalar.
*/
const check_shapes& scalar() const
{
if(not this->all_of([](const shape& s) { return s.scalar(); }))
MIGRAPHX_THROW(prefix() + "Shapes are not a scalar");
return *this;
}
/*! /*!
* Check all shapes are standard or scalar. * Check all shapes are standard or scalar.
*/ */
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#define MIGRAPHX_GUARD_OPERATORS_FILL_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/dyn_output.hpp>
#include <migraphx/par_for.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
/**
* fill(default_value, output_buffer)
* Fill an output buffer with the given default_value.
* Note that if the default_value is a literal and the output_buffer
* has a static shape this operator can be replaced with a literal.
*/
struct fill
{
std::string name() const { return "fill"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(2).same_type();
if(inputs.at(0).dynamic() or inputs.at(0).elements() != 1)
{
MIGRAPHX_THROW("FILL: default_value is dynamic or more than one element");
}
return inputs.back();
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
visit_all(args[0], args[1])([&](auto value, auto output) {
par_for(dyn_out.computed_shape.elements(), [&](auto i) { output[i] = value.front(); });
});
return args[1];
}
std::ptrdiff_t output_alias(const std::vector<shape>&) const { return 1; }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -124,7 +124,7 @@ struct roialign ...@@ -124,7 +124,7 @@ struct roialign
{ {
xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] + xy[ii] = roi_start[ii] + p[ii] * bin_size[ii] +
(i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii]; (i[ii] + .5f) * bin_size[ii] / bin_grid_size[ii];
xy[ii] = (coord_trans_mode == "output_half_pixel") ? (xy[ii] - 0.5f) : xy[ii]; xy[ii] = (coord_trans_mode == "half_pixel") ? (xy[ii] - 0.5f) : xy[ii];
if(xy[ii] < -1.0 or xy[ii] > dims[ii]) if(xy[ii] < -1.0 or xy[ii] > dims[ii])
{ {
results[index] = pos_weight{}; results[index] = pos_weight{};
......
...@@ -55,6 +55,7 @@ ...@@ -55,6 +55,7 @@
#include <migraphx/op/equal.hpp> #include <migraphx/op/equal.hpp>
#include <migraphx/op/erf.hpp> #include <migraphx/op/erf.hpp>
#include <migraphx/op/exp.hpp> #include <migraphx/op/exp.hpp>
#include <migraphx/op/fill.hpp>
#include <migraphx/op/flatten.hpp> #include <migraphx/op/flatten.hpp>
#include <migraphx/op/floor.hpp> #include <migraphx/op/floor.hpp>
#include <migraphx/op/fmod.hpp> #include <migraphx/op/fmod.hpp>
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/literal.hpp> #include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant> ...@@ -39,16 +40,38 @@ struct parse_constant : op_parser<parse_constant>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& /*args*/) const const std::vector<instruction_ref>& /*args*/) const
{ {
literal v = parser.parse_value(info.attributes.at("value")); static const std::vector<std::string> attributes = {
"value", "value_float", "value_floats", "value_int", "value_ints"};
std::vector<std::string> present_attributes;
std::copy_if(attributes.begin(),
attributes.end(),
std::back_inserter(present_attributes),
[&](const std::string& a) { return contains(info.attributes, a); });
if(present_attributes.empty())
{
MIGRAPHX_THROW("Constant node does not contain any supported attribute");
}
if(present_attributes.size() > 1)
{
MIGRAPHX_THROW("Constant contains multiple attributes: " +
join_strings(std::move(present_attributes), ", "));
}
// cppcheck-suppress accessMoved
auto&& attr = info.attributes[present_attributes[0]];
literal v = parser.parse_value(attr);
// return empty literal // return empty literal
if(v.get_shape().elements() == 0) if(v.get_shape().elements() == 0)
{ {
return info.add_literal(literal{v.get_shape().type()}); return info.add_literal(literal{v.get_shape().type()});
} }
auto dim_size = info.attributes.at("value").t().dims_size();
// if dim_size is 0, it is a scalar // if dim_size is 0, it is a scalar
if(dim_size == 0) if(attr.has_t() and attr.t().dims_size() == 0)
{ {
migraphx::shape scalar_shape{v.get_shape().type()}; migraphx::shape scalar_shape{v.get_shape().type()};
return info.add_literal(migraphx::literal{scalar_shape, v.data()}); return info.add_literal(migraphx::literal{scalar_shape, v.data()});
......
...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign> ...@@ -37,15 +37,18 @@ struct parse_roialign : op_parser<parse_roialign>
std::vector<op_desc> operators() const { return {{"RoiAlign"}}; } std::vector<op_desc> operators() const { return {{"RoiAlign"}}; }
instruction_ref parse(const op_desc& /*opd*/, instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/, const onnx_parser& parser,
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
std::string coord_trans_mode = "half_pixel"; std::string coord_trans_mode =
if(contains(info.attributes, "coordinate_transformation_mode")) parser.opset_version >= 16 ? "half_pixel" : "output_half_pixel";
if(const auto* a = "coordinate_transformation_mode"; contains(info.attributes, a))
{ {
coord_trans_mode = info.attributes.at("coordinate_transformation_mode").s(); coord_trans_mode = info.attributes.at(a).s();
} }
if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode)) if(not contains({"half_pixel", "output_half_pixel"}, coord_trans_mode))
{ {
MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode + MIGRAPHX_THROW("coordinate_transformation_mode \"" + coord_trans_mode +
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const ...@@ -36,8 +36,12 @@ void optimize_module::apply(module_pass_manager& mpm) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 2; i++)
{ {
mpm.run_pass(simplify_reshapes{}); // loop to further optimize after initial transformations
mpm.run_pass(simplify_algebra{}); for(int j = 0; j < 2; j++)
{
mpm.run_pass(simplify_reshapes{});
mpm.run_pass(simplify_algebra{});
}
mpm.run_pass(eliminate_common_subexpression{}); mpm.run_pass(eliminate_common_subexpression{});
mpm.run_pass(dead_code_elimination{}); mpm.run_pass(dead_code_elimination{});
mpm.run_pass(propagate_constant{}); mpm.run_pass(propagate_constant{});
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -35,10 +35,10 @@ inline namespace MIGRAPHX_INLINE_NS {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_PROPAGATE_CONSTANT)
bool skip_propogate(instruction_ref ins) bool skip_propagate(instruction_ref ins)
{ {
if(ins->name() == "contiguous") if(ins->name() == "contiguous")
return skip_propogate(ins->inputs().front()); return skip_propagate(ins->inputs().front());
auto&& s = ins->get_shape(); auto&& s = ins->get_shape();
if(s.broadcasted() and not s.scalar()) if(s.broadcasted() and not s.scalar())
return true; return true;
...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins) ...@@ -47,7 +47,7 @@ bool skip_propogate(instruction_ref ins)
return false; return false;
} }
bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propogate(ins); } bool is_const_ins(instruction_ref ins) { return ins->can_eval() and not skip_propagate(ins); }
void propagate_constant::apply(module& m) const void propagate_constant::apply(module& m) const
{ {
......
...@@ -1325,48 +1325,59 @@ struct find_split_reshape ...@@ -1325,48 +1325,59 @@ struct find_split_reshape
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto slc = r.instructions["slice"]; auto slc = r.instructions["slice"];
auto rsp = r.instructions["reshape"]; auto rsp = r.instructions["reshape"];
auto input = slc->inputs().front();
// Only apply simplification when slices are on a single axis
auto axes = any_cast<op::slice>(slc->get_operator()).axes;
if(axes.size() > 1)
{
return;
}
auto input = slc->inputs().front();
auto split_outputs = get_splits(input); auto split_outputs = get_splits(input);
if(split_outputs.empty()) if(split_outputs.empty())
{ {
return; return;
} }
// Only want to apply this optimization if each split output is followed by // Find all the reshapes (similar to rsp) that can be simplified
// a contiguous op and a reshape std::vector<instruction_ref> conts;
if(std::any_of(split_outputs.begin(), split_outputs.end(), [](auto i) { std::vector<instruction_ref> vec_rsp;
if(i->outputs().size() == 1)
{ // Iterate through slice and contiguous outputs to allow simplifications when
auto cont = i->outputs().front(); // slice is followed by multiple reshapes
return cont->outputs().size() != 1; for(auto& i : split_outputs)
}
return false;
}))
{ {
return; std::copy_if(i->outputs().begin(),
i->outputs().end(),
std::back_inserter(conts),
[](auto j) { return j->name() == "contiguous"; });
} }
std::vector<instruction_ref> vec_rsp(split_outputs.size()); for(auto& i : conts)
std::transform(split_outputs.begin(), split_outputs.end(), vec_rsp.begin(), [](auto i) { {
auto cont = i->outputs().front(); std::copy_if(i->outputs().begin(),
return cont->outputs().front(); i->outputs().end(),
}); std::back_inserter(vec_rsp),
[&](auto j) { return j->get_operator() == rsp->get_operator(); });
}
// all outputs are reshape and of the same shape // No simplification needed if there is only one slice -> cont -> reshape
auto dims = any_cast<op::reshape>(rsp->get_operator()).dims; if(vec_rsp.size() <= 1)
if(not same_ops(vec_rsp))
{ {
return; return;
} }
// ensure reshape happens after the axis dimension // ensure reshape happens after the axis dimension
auto axis = any_cast<op::slice>(slc->get_operator()).axes[0]; auto axis = axes[0];
auto slc_lens = slc->get_shape().lens(); auto slc_lens = slc->get_shape().lens();
auto slc_dim_size = std::accumulate( auto slc_dim_size = std::accumulate(
slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>()); slc_lens.begin() + axis, slc_lens.end(), 1, std::multiplies<std::size_t>());
auto input_lens = input->get_shape().lens();
auto input_size = input->get_shape().elements();
auto slc_axis_len = input_lens[axis];
// search the reshape output (standard shape) to decide which axis are // search the reshape output (standard shape) to decide which axis are
// in its output corresponding to the slc_dim_size // in its output corresponding to the slc_dim_size
...@@ -1393,16 +1404,67 @@ struct find_split_reshape ...@@ -1393,16 +1404,67 @@ struct find_split_reshape
{ {
rsp_axis = std::distance(rsp_strides.begin(), ait); rsp_axis = std::distance(rsp_strides.begin(), ait);
} }
// calculate reshape output shape
std::vector<int64_t> vec_dims(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), vec_dims.begin(), [&](auto is) { // Calculate reshape output shape
return is->get_shape().lens()[rsp_axis]; // Need to find a reshape such that data represented by instructions in vec_rsp can be
}); // written as slices of this new reshape. This is done by holding all the dims constant in
// rsp_lens to compute the required dim for rsp_axis (axis that will be sliced)
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1, 2, 4}, rsp_fixed_size = 2*1*2*4 = 16
// rsp_axis_len = 2*12*4 / 16 = 6
// rsp_out_lens (final) = {2, 6, 2, 4}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// rsp_axis = 1, rsp_out_lens (initial) = {2, 1}, rsp_fixed_size = 2*1 = 2
// rsp_axis_len = 2*12*4 / 2 = 48
// rsp_out_lens (final) = {2, 48}
std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end()); std::vector<int64_t> rsp_out_lens(rsp_lens.begin(), rsp_lens.end());
rsp_out_lens[rsp_axis] = 1;
auto rsp_fixed_size = std::accumulate(
rsp_out_lens.begin(), rsp_out_lens.end(), 1, std::multiplies<std::size_t>());
rsp_out_lens[rsp_axis] = std::accumulate(vec_dims.begin(), vec_dims.end(), std::int64_t{0}); // cannot create a valid reshape for simplification
if(input_size % rsp_fixed_size != 0)
{
return;
}
auto rsp_axis_len = input_size / rsp_fixed_size;
rsp_out_lens[rsp_axis] = rsp_axis_len;
// Calculate new slice start and end indices. Indices are scaled using the new reshape axis
// and the original slice axis. See examples:
// ex 1: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 2, 2, 4}, {2, 2, 2, 4}, {2, 2, 2, 4}
// slc_axis_len = 12, rsp_axis_len = 6
// New Starts: {0*6/12, 4*6/12, 8*6/12} = {0, 2, 4}
// New Ends: {4*6/12, 8*6/12, 12*6/12} = {2, 4, 6}
// ex 2: Input Shape: {2, 12, 4}, Slice Axis: 1, Slices are: (0:4), (4:8), (8:12),
// Reshape Outputs: {2, 16}, {2, 16}, {2, 16}
// slc_axis_len = 12, rsp_axis_len = 48
// New Starts: {0*48/12, 4*48/12, 8*48/12} = { 0, 16, 32}
// New Ends: {4*48/12, 8*48/12, 12*48/12} = {16, 32, 48}
std::vector<int64_t> new_starts(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_starts.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).starts[0] * rsp_axis_len /
slc_axis_len;
});
std::vector<int64_t> new_ends(vec_rsp.size());
std::transform(vec_rsp.begin(), vec_rsp.end(), new_ends.begin(), [&](auto is) {
auto cont = is->inputs().front();
auto og_slc = cont->inputs().front();
return any_cast<op::slice>(og_slc->get_operator()).ends[0] * rsp_axis_len /
slc_axis_len;
});
// insert the reshape instruction and add contiguous if needed // insert the reshape instruction and add contiguous if needed
if(not input->get_shape().standard()) if(not input->get_shape().standard())
...@@ -1413,16 +1475,14 @@ struct find_split_reshape ...@@ -1413,16 +1475,14 @@ struct find_split_reshape
std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input); std::next(input), make_op("reshape", {{"dims", rsp_out_lens}}), input);
// replace the original reshape with slice // replace the original reshape with slice
int64_t start = 0;
for(std::size_t i = 0; i < vec_rsp.size(); ++i) for(std::size_t i = 0; i < vec_rsp.size(); ++i)
{ {
m.replace_instruction( m.replace_instruction(
vec_rsp[i], vec_rsp[i],
make_op( make_op(
"slice", "slice",
{{"axes", {rsp_axis}}, {"starts", {start}}, {"ends", {start + vec_dims[i]}}}), {{"axes", {rsp_axis}}, {"starts", {new_starts[i]}}, {"ends", {new_ends[i]}}}),
rsp_ins); rsp_ins);
start += vec_dims[i];
} }
} }
}; };
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary ...@@ -627,6 +627,30 @@ struct find_transpose_contiguous_reshaper_unary
} }
}; };
struct find_broadcast_transpose
{
auto matcher() const
{
return match::name("transpose")(
match::arg(0)(match::name("multibroadcast").bind("bcast_ins")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins_lens = ins->get_shape().lens();
auto bcast_ins = r.instructions["bcast_ins"];
auto input = bcast_ins->inputs().front();
// for now, focusing on scalar transformation
if(not input->get_shape().scalar())
return;
auto new_mbcast = m.insert_instruction(
bcast_ins, make_op("multibroadcast", {{"out_lens", ins_lens}}), input);
m.replace_instruction(ins, new_mbcast);
}
};
struct find_slice_transpose struct find_slice_transpose
{ {
auto matcher() const auto matcher() const
...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -799,6 +823,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{}, find_transpose_slice{},
find_broadcast_transpose{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS ...@@ -50,6 +50,7 @@ file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp) file(GLOB DEVICE_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/device/*.cpp)
add_library(migraphx_device ${DEVICE_GPU_SRCS}) add_library(migraphx_device ${DEVICE_GPU_SRCS})
...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device) ...@@ -69,6 +70,7 @@ rocm_clang_tidy_check(migraphx_device)
target_link_libraries(migraphx_device PUBLIC migraphx) target_link_libraries(migraphx_device PUBLIC migraphx)
target_link_libraries(migraphx_device PRIVATE compile_for_gpu) target_link_libraries(migraphx_device PRIVATE compile_for_gpu)
target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>) target_include_directories(migraphx_device PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_BINAR_DIR}/include>)
target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>) target_include_directories(migraphx_device PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/device/include>)
target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes) target_compile_options(migraphx_device PRIVATE -Wno-ignored-attributes)
migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device) migraphx_generate_export_header(migraphx_device DIRECTORY migraphx/gpu/device)
......
...@@ -26,7 +26,9 @@ ...@@ -26,7 +26,9 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/gpu/device/types.hpp> #include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/targets.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -84,8 +86,15 @@ inline auto launch(hipStream_t stream, index_int global, index_int local) ...@@ -84,8 +86,15 @@ inline auto launch(hipStream_t stream, index_int global, index_int local)
hipError_t kernel_launch_status = hipGetLastError(); hipError_t kernel_launch_status = hipGetLastError();
if(kernel_launch_status != hipSuccess) if(kernel_launch_status != hipSuccess)
{ {
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + std::string message = hipGetErrorString(kernel_launch_status);
std::string(hipGetErrorString(kernel_launch_status))); if(not contains(get_targets(), get_device_name()))
{
message += ". Trying to run a kernel for " + get_device_name() +
" but MIGraphX was built for targets " + get_targets_as_string() +
". Please rebuild MIGraphX with -DGPU_TARGETS='" + get_device_name() +
"'.";
}
MIGRAPHX_THROW("MIGraphX device kernel failed to launch with error: " + message);
} }
}; };
} }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/device/targets.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/errors.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
static std::vector<std::string> parse_targets() { return split_string(MIGRAPHX_GPU_TARGETS, ';'); }
const std::vector<std::string>& get_targets()
{
static auto result = parse_targets();
return result;
}
std::string get_targets_as_string() { return join_strings(get_targets(), ", "); }
static int get_device_id()
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
MIGRAPHX_THROW("No device");
return device;
}
std::string get_device_name()
{
hipDeviceProp_t props{};
auto status = hipGetDeviceProperties(&props, get_device_id());
if(status != hipSuccess)
MIGRAPHX_THROW("Failed to get device properties");
return props.gcnArchName;
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <string>
#include <vector>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
const std::vector<std::string>& get_targets();
std::string get_targets_as_string();
std::string get_device_name();
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
...@@ -103,7 +103,10 @@ struct mlir_op ...@@ -103,7 +103,10 @@ struct mlir_op
} }
if(ins->name() == "@return") if(ins->name() == "@return")
{ {
return ins_shapes[ins->inputs().at(0)].with_type(type); auto s = ins_shapes[ins->inputs().at(0)].with_type(type);
if(not s.standard())
MIGRAPHX_THROW("MLIR doesnt support non-standard output");
return s;
} }
std::vector<shape> input_shapes; std::vector<shape> input_shapes;
input_shapes.resize(ins->inputs().size()); input_shapes.resize(ins->inputs().size());
...@@ -299,10 +302,8 @@ struct find_mlir_fused_ops ...@@ -299,10 +302,8 @@ struct find_mlir_fused_ops
} }
}; };
struct find_mlir_standalone_convolution_op struct find_mlir_standalone_op
{ {
auto matcher() const { return match::name("convolution"); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto conv_based_op = r.result; auto conv_based_op = r.result;
...@@ -324,6 +325,16 @@ struct find_mlir_standalone_convolution_op ...@@ -324,6 +325,16 @@ struct find_mlir_standalone_convolution_op
} }
}; };
struct find_mlir_standalone_convolution_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("convolution"); }
};
struct find_mlir_standalone_dot_op : find_mlir_standalone_op
{
auto matcher() const { return match::name("dot"); }
};
/** /**
* @brief Declares a new MIGraphX environment variable which forces to generate * @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations. * only specific MLIR operations.
...@@ -331,7 +342,7 @@ struct find_mlir_standalone_convolution_op ...@@ -331,7 +342,7 @@ struct find_mlir_standalone_convolution_op
* The variable, if defined, forces MIGraphX to use only specific operations * The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts * with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following * a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution". If the variable is not defined MIGraphX * operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is * will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers. * intended to be primarily used by rocMLIR developers.
*/ */
...@@ -346,31 +357,33 @@ bool is_requested(std::string_view option) ...@@ -346,31 +357,33 @@ bool is_requested(std::string_view option)
return contains(options, option); return contains(options, option);
} }
bool is_fusion_enabled() bool is_enabled(std::string_view op_name, context* ctx)
{ {
if(is_self_decide()) if(is_self_decide())
{ {
return true; if(op_name == "fused")
}
return is_requested("fused");
}
bool is_standalone_convs_enabled(context* ctx)
{
if(is_self_decide())
{
if(ctx == nullptr)
{ {
return false; return true;
}
else if(op_name == "convolution")
{
if(ctx == nullptr)
{
return false;
}
else
{
const auto& device = ctx->get_current_device();
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
}
} }
else else
{ {
const auto& device = ctx->get_current_device(); return false;
const std::string navi_family{"gfx110"};
return starts_with(device.get_gfx_name(), navi_family);
} }
} }
return is_requested("convolution"); return is_requested(op_name);
} }
} // namespace } // namespace
...@@ -379,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx) ...@@ -379,21 +392,25 @@ bool is_standalone_convs_enabled(context* ctx)
void fuse_mlir::apply(module_pass_manager& mpm) const void fuse_mlir::apply(module_pass_manager& mpm) const
{ {
#ifdef MIGRAPHX_MLIR #ifdef MIGRAPHX_MLIR
if(is_fusion_enabled()) if(is_enabled("fused", this->ctx))
{ {
match::find_matches(mpm, find_mlir_fused_ops{}); match::find_matches(mpm, find_mlir_fused_ops{});
} }
if(is_standalone_convs_enabled(this->ctx)) if(is_enabled("convolution", this->ctx))
{ {
match::find_matches(mpm, find_mlir_standalone_convolution_op{}); match::find_matches(mpm, find_mlir_standalone_convolution_op{});
} }
if(is_enabled("dot", this->ctx))
{
match::find_matches(mpm, find_mlir_standalone_dot_op{});
}
#else #else
(void)mpm; (void)mpm;
#endif #endif
} }
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m, ...@@ -49,7 +49,8 @@ MIGRAPHX_GPU_EXPORT instruction_ref insert_mlir(module& m,
MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx, MIGRAPHX_GPU_EXPORT tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
module m, module m,
const std::vector<shape>& inputs); const std::vector<shape>& inputs,
bool exhaustive);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler> ...@@ -57,11 +57,9 @@ struct mlir_compiler : compiler<mlir_compiler>
const operation&, const operation&,
bool exhaustive) const bool exhaustive) const
{ {
if(not exhaustive)
return nullopt;
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
auto* smod = ins->module_inputs().front(); auto* smod = ins->module_inputs().front();
return get_tuning_config_mlir(ctx, *smod, shapes); return get_tuning_config_mlir(ctx, *smod, shapes, exhaustive);
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment