Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
a7d57049
Commit
a7d57049
authored
Oct 04, 2023
by
Alan Turner
Browse files
Move gemm_softmax_gemm matching to prefuse_ops
parent
c2bafa5d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
113 deletions
+150
-113
src/include/migraphx/ck.hpp
src/include/migraphx/ck.hpp
+81
-0
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+8
-81
src/targets/gpu/include/migraphx/gpu/ck.hpp
src/targets/gpu/include/migraphx/gpu/ck.hpp
+2
-27
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+58
-2
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+1
-3
No files found.
src/include/migraphx/ck.hpp
0 → 100644
View file @
a7d57049
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_CK_HPP
#define MIGRAPHX_GUARD_CK_HPP
#include <migraphx/env.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
#endif
struct
gemm_softmax_gemm
{
operation
op
=
make_op
(
"dot"
);
float
scale
=
1.0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
scale
,
"scale"
));
}
std
::
string
name
()
const
{
return
"pre_gemm_softmax_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm_softmax_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
3
)
MIGRAPHX_THROW
(
name
()
+
": Expected 3 inputs but got "
+
to_string
(
inputs
.
size
()));
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
auto
b1
=
inputs
[
2
];
for
(
const
auto
&
input
:
inputs
)
{
check_gemm_shape
(
input
);
}
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/fuse_ck.cpp
View file @
a7d57049
...
...
@@ -24,8 +24,8 @@
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/ck.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -73,41 +73,9 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_softmax_gemm
struct
ck_gemm_softmax_gemm
:
gemm_softmax_gemm
{
operation
op
=
make_op
(
"dot"
);
float
scale
=
1.0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
scale
,
"scale"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_softmax_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm_softmax_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
3
)
MIGRAPHX_THROW
(
name
()
+
": Expected 3 inputs but got "
+
to_string
(
inputs
.
size
()));
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
auto
b1
=
inputs
[
2
];
for
(
const
auto
&
input
:
inputs
)
{
check_gemm_shape
(
input
);
}
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_softmax_gemm
);
...
...
@@ -203,61 +171,21 @@ struct find_ck_gemm
}
};
auto
is_mul_module
(
module
&
m
)
{
auto
is_mul
=
match
::
arg
(
0
)(
match
::
name
(
"mul"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"@param"
))));
return
match_instruction
(
m
,
std
::
prev
(
m
.
end
()),
is_mul
).
result
!=
m
.
end
();
}
MIGRAPHX_PRED_MATCHER
(
is_pointwise_scale
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"pointwise"
)
return
false
;
if
(
ins
->
module_inputs
().
size
()
!=
1
)
return
false
;
return
is_mul_module
(
*
ins
->
module_inputs
().
front
());
}
struct
find_ck_gemm_softmax_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"pointwise"
)(
match
::
nargs
(
2
),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"scale"
),
gemm1
))(
is_pointwise_scale
());
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
arg
(
0
)(
mul
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
return
match
::
name
(
"gpu::pre_gemm_softmax_gemm"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm2_ins
=
r
.
instructions
[
"gemm2"
];
auto
gemm1_ins
=
r
.
instructions
[
"gemm1"
];
auto
scale_lit
=
r
.
instructions
[
"scale"
];
if
(
not
ck_gemm_softmax_gemm
::
is_ck_supported_type
(
gemm1_ins
->
get_shape
().
type
()))
return
;
float
scale
=
1.0
;
scale_lit
->
eval
().
visit
([
&
](
const
auto
s
)
{
// CK only supports single-valued scale
if
(
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
s
.
front
());
}))
scale
=
s
.
front
();
else
return
;
});
auto
inputs
=
gemm1_ins
->
inputs
();
// A, B
inputs
.
push_back
(
gemm2_ins
->
inputs
().
back
());
// B1
auto
v
=
ins
->
get_operator
().
to_value
();
assert
(
v
.
contains
(
"scale"
));
auto
scale
=
v
.
at
(
"scale"
).
to
<
float
>
();
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_softmax_gemm
{
gemm2_ins
->
get_operator
(
),
scale
},
inputs
);
ins
,
ck_gemm_softmax_gemm
{
migraphx
::
make_op
(
"dot"
),
scale
},
ins
->
inputs
()
);
}
};
...
...
@@ -265,8 +193,7 @@ struct find_ck_gemm_softmax_gemm
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_ck_gemm_softmax_gemm
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm_softmax_gemm
{},
find_ck_gemm_pointwise
{});
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
...
...
src/targets/gpu/include/migraphx/gpu/ck.hpp
View file @
a7d57049
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_GPU_CK_HPP
#define MIGRAPHX_GUARD_GPU_CK_HPP
#include <migraphx/c
onfig
.hpp>
#include <migraphx/c
k
.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/compile_src.hpp>
...
...
@@ -35,10 +35,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TUNE_CK
);
// NOLINTNEXTLINE
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
...
...
@@ -78,7 +74,7 @@ static std::vector<src_file> create_ck_headers()
return
srcs
;
}
static
const
std
::
vector
<
src_file
>&
ck_headers
()
static
inline
const
std
::
vector
<
src_file
>&
ck_headers
()
{
static
const
auto
&
headers
=
create_ck_headers
();
return
headers
;
...
...
@@ -86,27 +82,6 @@ static const std::vector<src_file>& ck_headers()
inline
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
inline
float
matrix_distance
(
const
shape
&
x
,
const
shape
&
y
)
{
if
(
x
.
type
()
!=
y
.
type
())
return
std
::
numeric_limits
<
float
>::
max
();
if
(
transposed_matrix
(
x
)
!=
transposed_matrix
(
y
))
return
std
::
numeric_limits
<
float
>::
max
();
auto
sum_squared
=
std
::
inner_product
(
x
.
lens
().
rbegin
(),
x
.
lens
().
rbegin
()
+
2
,
y
.
lens
().
rbegin
(),
0
,
std
::
plus
<>
{},
[](
auto
a
,
auto
b
)
{
return
(
a
-
b
)
*
(
a
-
b
);
});
return
std
::
sqrt
(
sum_squared
);
}
inline
std
::
string
get_layout
(
const
shape
&
s
)
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
inline
ck
::
host
::
DataType
get_type
(
const
shape
&
s
)
{
if
(
s
.
type
()
==
shape
::
half_type
)
...
...
src/targets/gpu/prefuse_ops.cpp
View file @
a7d57049
...
...
@@ -24,15 +24,15 @@
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/ck.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
{
template
<
class
Derived
,
std
::
size_t
N
>
...
...
@@ -120,6 +120,60 @@ struct find_add_layernorm
m
.
replace_instruction
(
ins
,
add_layernorm
{
op
.
epsilon
},
add_ins
->
inputs
());
}
};
struct
pre_gemm_softmax_gemm
:
gemm_softmax_gemm
{
std
::
string
name
()
const
{
return
"gpu::pre_gemm_softmax_gemm"
;
}
};
MIGRAPHX_REGISTER_OP
(
pre_gemm_softmax_gemm
);
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
if
(
not
pre_gemm_softmax_gemm
::
is_ck_supported_type
(
ins
->
get_shape
().
type
()))
return
false
;
return
true
;
}
struct
find_gemm_softmax_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
nargs
(
2
),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"scale"
),
gemm1
));
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
arg
(
0
)(
mul
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
arg
(
0
)(
softmax
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm2_ins
=
r
.
instructions
[
"gemm2"
];
auto
gemm1_ins
=
r
.
instructions
[
"gemm1"
];
auto
scale_lit
=
r
.
instructions
[
"scale"
];
float
scale
=
1.0
;
scale_lit
->
eval
().
visit
([
&
](
const
auto
s
)
{
// CK only supports single-valued scale
if
(
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
auto
v
)
{
return
float_equal
(
v
,
s
.
front
());
}))
scale
=
s
.
front
();
else
return
;
});
auto
inputs
=
gemm1_ins
->
inputs
();
// A, B
inputs
.
push_back
(
gemm2_ins
->
inputs
().
back
());
// B1
mpm
.
get_module
().
replace_instruction
(
ins
,
pre_gemm_softmax_gemm
{
gemm2_ins
->
get_operator
(),
scale
},
inputs
);
}
};
}
// namespace
void
prefuse_ops
::
apply
(
module_pass_manager
&
mpm
)
const
...
...
@@ -127,6 +181,8 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match
::
find_matches
(
mpm
.
get_module
(),
find_layernorm
{});
mpm
.
run_pass
(
dead_code_elimination
{});
match
::
find_matches
(
mpm
.
get_module
(),
find_add_layernorm
{});
if
(
enabled
(
MIGRAPHX_ENABLE_CK
{}))
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm
{});
}
}
// namespace gpu
...
...
src/targets/gpu/target.cpp
View file @
a7d57049
...
...
@@ -53,6 +53,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/split_single_dyn_dim.hpp>
#include <migraphx/gpu/allocation_model.hpp>
#include <migraphx/ck.hpp>
#include <migraphx/gpu/compile_miopen.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/gpu/concat_gpu_opt.hpp>
...
...
@@ -76,9 +77,6 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_SCHEDULE_PASS
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_REDUCE_FUSION
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_NHWC
)
#ifndef _WIN32
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_CK
)
#endif
struct
id_pass
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment