Commit 2e79bb1b authored by Alan Turner's avatar Alan Turner
Browse files

remove debug prints from fuse_ck

parent f83139de
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
namespace gpu {
struct ck_gemm
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm"; }
void check_gemm_shape(const shape& s) const
{
if(contains(s.lens(), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.not_broadcasted();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
auto a = inputs[n - 2];
auto b = inputs[n - 1];
check_gemm_shape(a);
check_gemm_shape(b);
return op.compute_shape({a, b});
}
};
MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_add_add_gelu
{
operation op = make_op("dot");
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"));
}
std::string name() const { return "gpu::ck_gemm_add_add_gelu"; }
void check_gemm_shape(const shape& s) const
{
if(contains(s.lens(), 1))
MIGRAPHX_THROW("Invalid shape for ck_gemm");
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.not_broadcasted();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size();
auto a = inputs[n - 2];
auto b = inputs[n - 1];
check_gemm_shape(a);
check_gemm_shape(b);
return op.compute_shape({a, b});
}
};
MIGRAPHX_REGISTER_OP(ck_gemm_add_add_gelu);
namespace {
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
if(a.lens().size() > 2 or b.lens().size() > 2)
return false;
return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
b.lens()[1] % 8 == 0);
}
struct find_ck_gemm
{
// Find a convolution followed by a pointwise operation.
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs());
}
};
struct find_ck_gemm_pointwise
{
auto matcher() const { return match::name("pointwise")(match::arg(0)(match::name("dot")(is_ck_gemm().bind("gemm")))); }
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm = r.instructions["gemm"];
auto inputs = gemm->inputs();
for (auto in : ins->inputs())
{
if (in != gemm)
inputs.push_back(in);
}
mpm.get_module().replace_instruction(ins, ck_gemm_add_add_gelu{gemm->get_operator()}, inputs);
mpm.get_module().remove_instruction(gemm);
}
};
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm_pointwise{}); }
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
namespace gpu {
struct fuse_ck
{
context* ctx = nullptr;
std::string name() const { return "gpu::fuse_ck"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_fuse_ck : verify_program<test_fuse_ck>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
unsigned long m = 256;
unsigned long k = m;
unsigned long n = k;
migraphx::shape m1_shape{migraphx::shape::half_type, {m, k}};
migraphx::shape m2_shape{migraphx::shape::half_type, {k, n}};
migraphx::shape m3_shape{migraphx::shape::half_type, {m, n}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
auto l4 = mm->add_parameter("4", m3_shape);
auto gemm = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
auto add = mm->add_instruction(migraphx::make_op("add"), gemm, l3);
auto x = mm->add_instruction(migraphx::make_op("add"), add, l4);
std::vector<size_t> input_lens{m, n};
migraphx::shape m4_shape{migraphx::shape::half_type, {1}};
auto half = mm->add_literal(migraphx::literal{m4_shape, {0.5}});
auto one = mm->add_literal(migraphx::literal{m4_shape, {1.0}});
auto sqrt2 = mm->add_literal(migraphx::literal{m4_shape, {M_SQRT2}});
auto half_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), half);
auto mul_half = mm->add_instruction(migraphx::make_op("mul"), x, half_mbcast);
auto sqrt2_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), sqrt2);
auto div = mm->add_instruction(migraphx::make_op("div"), x, sqrt2_mbcast);
auto erf = mm->add_instruction(migraphx::make_op("erf"), div);
auto one_mbcast = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", input_lens}}), one);
auto add_one = mm->add_instruction(migraphx::make_op("add"), erf, one_mbcast);
mm->add_instruction(migraphx::make_op("mul"), mul_half, add_one);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_ck_gemm : verify_program<test_ck_gemm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
unsigned long m = 256;
unsigned long k = m;//4096;
unsigned long n = k;//4096;
migraphx::shape m1_shape{migraphx::shape::half_type, {m, k}};
migraphx::shape m2_shape{migraphx::shape::half_type, {k, n}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
// migraphx::shape m1_shape{migraphx::shape::half_type, {1}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1}};
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, {1}});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, {1}});
// l1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {m, k}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {k, n}}}), l2);
mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
return p;
}
};
// struct test_ck_gemm : verify_program<test_ck_gemm>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// unsigned long m = 3; unsigned long k = 3; unsigned long n = 3;
// migraphx::shape m1_shape{migraphx::shape::half_type, {m, k}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {k, n}};
// std::vector<float> v1(m * k, 1);
// //std::iota(v1.begin(), v1.end(), 1);
// std::vector<float> v2(k * n, 1);
// std::iota(v2.begin(), v2.end(), 1);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// // auto l1 = mm->add_parameter("1", m1_shape);
// // auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// return p;
// }
// };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment