Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2e79bb1b
Commit
2e79bb1b
authored
Oct 12, 2022
by
Alan Turner
Browse files
remove debug prints from fuse_ck
parent
f83139de
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
313 additions
and
0 deletions
+313
-0
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+137
-0
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
+26
-0
test/verify/0_test_fuse_ck.cpp
test/verify/0_test_fuse_ck.cpp
+70
-0
test/verify/0ck_test_ck_gemm.cpp
test/verify/0ck_test_ck_gemm.cpp
+80
-0
No files found.
src/targets/gpu/fuse_ck.cpp
0 → 100644
View file @
2e79bb1b
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
namespace
gpu
{
struct
ck_gemm
{
operation
op
=
make_op
(
"dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
contains
(
s
.
lens
(),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
not_broadcasted
();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
n
=
inputs
.
size
();
auto
a
=
inputs
[
n
-
2
];
auto
b
=
inputs
[
n
-
1
];
check_gemm_shape
(
a
);
check_gemm_shape
(
b
);
return
op
.
compute_shape
({
a
,
b
});
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_add_add_gelu
{
operation
op
=
make_op
(
"dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_add_add_gelu"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
contains
(
s
.
lens
(),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
not_broadcasted
();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
n
=
inputs
.
size
();
auto
a
=
inputs
[
n
-
2
];
auto
b
=
inputs
[
n
-
1
];
check_gemm_shape
(
a
);
check_gemm_shape
(
b
);
return
op
.
compute_shape
({
a
,
b
});
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_add_add_gelu
);
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
return
false
;
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
b
.
lens
()[
1
]
%
8
==
0
);
}
struct
find_ck_gemm
{
// Find a convolution followed by a pointwise operation.
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm
{
ins
->
get_operator
()},
ins
->
inputs
());
}
};
struct
find_ck_gemm_pointwise
{
auto
matcher
()
const
{
return
match
::
name
(
"pointwise"
)(
match
::
arg
(
0
)(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
))));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm
=
r
.
instructions
[
"gemm"
];
auto
inputs
=
gemm
->
inputs
();
for
(
auto
in
:
ins
->
inputs
())
{
if
(
in
!=
gemm
)
inputs
.
push_back
(
in
);
}
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_add_add_gelu
{
gemm
->
get_operator
()},
inputs
);
mpm
.
get_module
().
remove_instruction
(
gemm
);
}
};
}
// namespace
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/include/migraphx/gpu/fuse_ck.hpp
0 → 100644
View file @
2e79bb1b
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
namespace
gpu
{
struct
fuse_ck
{
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"gpu::fuse_ck"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
test/verify/0_test_fuse_ck.cpp
0 → 100644
View file @
2e79bb1b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_fuse_ck
:
verify_program
<
test_fuse_ck
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
unsigned
long
m
=
256
;
unsigned
long
k
=
m
;
unsigned
long
n
=
k
;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
m
,
k
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
k
,
n
}};
migraphx
::
shape
m3_shape
{
migraphx
::
shape
::
half_type
,
{
m
,
n
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
auto
l3
=
mm
->
add_parameter
(
"3"
,
m3_shape
);
auto
l4
=
mm
->
add_parameter
(
"4"
,
m3_shape
);
auto
gemm
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
l1
,
l2
);
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
gemm
,
l3
);
auto
x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
add
,
l4
);
std
::
vector
<
size_t
>
input_lens
{
m
,
n
};
migraphx
::
shape
m4_shape
{
migraphx
::
shape
::
half_type
,
{
1
}};
auto
half
=
mm
->
add_literal
(
migraphx
::
literal
{
m4_shape
,
{
0.5
}});
auto
one
=
mm
->
add_literal
(
migraphx
::
literal
{
m4_shape
,
{
1.0
}});
auto
sqrt2
=
mm
->
add_literal
(
migraphx
::
literal
{
m4_shape
,
{
M_SQRT2
}});
auto
half_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
half
);
auto
mul_half
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
x
,
half_mbcast
);
auto
sqrt2_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
sqrt2
);
auto
div
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"div"
),
x
,
sqrt2_mbcast
);
auto
erf
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"erf"
),
div
);
auto
one_mbcast
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_lens
}}),
one
);
auto
add_one
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
erf
,
one_mbcast
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
mul_half
,
add_one
);
return
p
;
}
};
test/verify/0ck_test_ck_gemm.cpp
0 → 100644
View file @
2e79bb1b
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_ck_gemm
:
verify_program
<
test_ck_gemm
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
unsigned
long
m
=
256
;
unsigned
long
k
=
m
;
//4096;
unsigned
long
n
=
k
;
//4096;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
m
,
k
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
k
,
n
}};
auto
l1
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
l2
=
mm
->
add_parameter
(
"2"
,
m2_shape
);
// migraphx::shape m1_shape{migraphx::shape::half_type, {1}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {1}};
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, {1}});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, {1}});
// l1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {m, k}}}), l1);
// l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {k, n}}}), l2);
mm
->
add_instruction
(
migraphx
::
make_op
(
"ck_gemm"
),
l1
,
l2
);
return
p
;
}
};
// struct test_ck_gemm : verify_program<test_ck_gemm>
// {
// migraphx::program create_program() const
// {
// migraphx::program p;
// auto* mm = p.get_main_module();
// unsigned long m = 3; unsigned long k = 3; unsigned long n = 3;
// migraphx::shape m1_shape{migraphx::shape::half_type, {m, k}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {k, n}};
// std::vector<float> v1(m * k, 1);
// //std::iota(v1.begin(), v1.end(), 1);
// std::vector<float> v2(k * n, 1);
// std::iota(v2.begin(), v2.end(), 1);
// auto l1 = mm->add_literal(migraphx::literal{m1_shape, v1});
// auto l2 = mm->add_literal(migraphx::literal{m2_shape, v2});
// // auto l1 = mm->add_parameter("1", m1_shape);
// // auto l2 = mm->add_parameter("2", m2_shape);
// // l1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l1);
// // l2 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l2);
// mm->add_instruction(migraphx::make_op("ck_gemm"), l1, l2);
// return p;
// }
// };
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment