Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8752d11f
Commit
8752d11f
authored
Oct 18, 2022
by
Alan Turner
Browse files
Add specialization for lens not divisible by 8
parent
2d18473f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
4 deletions
+46
-4
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+21
-2
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+25
-2
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
8752d11f
...
...
@@ -54,12 +54,31 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
m
=
b
.
lens
()[
1
];
auto
n
=
a
.
lens
()[
0
];
auto
k
=
a
.
lens
()[
1
];
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
return
false
;
if
(
a
.
lens
()[
1
]
>=
2048
)
return
false
;
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
b
.
lens
()[
1
]
%
8
==
0
);
return
true
;
// std::cout << a << std::endl;
// std::cout << b << std::endl;
// printf("m, n, k: %zu, %zu, %zu\n", m, n, k);
// if ((m == 1414 and n == 2048 and k == 512) or
// (m == 4096 and n == 2048 and k == 1414) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 160 and n == 2048 and k == 64) or
// (m == 512 and n == 2048 and k == 512) or
// (m == 39488 and n == 2048 and k == 512) or
// (m == 5120 and n == 2048 and k == 512))
// return true;//(a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
// //b.lens()[1] % 8 == 0);
// return false;
}
struct
find_ck_gemm
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
8752d11f
...
...
@@ -40,7 +40,7 @@
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
const
std
::
vector
<
std
::
string
>&
std
::
vector
<
std
::
string
>&
get_instance
(
std
::
size_t
i
,
const
std
::
function
<
bool
(
const
std
::
vector
<
std
::
string
>&
)
>&
pred
);
namespace
migraphx
{
...
...
@@ -77,6 +77,8 @@ static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y -
static
std
::
size_t
block_size_index
=
13
;
static
std
::
size_t
padding_index
=
11
;
static
std
::
size_t
get_block_size
(
const
std
::
vector
<
std
::
string
>&
s
)
{
return
std
::
stoull
(
s
[
block_size_index
]);
...
...
@@ -89,6 +91,11 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
return
int_div_ceil
(
m
,
mpb
)
*
int_div_ceil
(
n
,
npb
);
}
static
void
set_padding
(
std
::
vector
<
std
::
string
>&
s
,
const
std
::
string
p
)
{
s
[
padding_index
]
=
p
;
}
template
<
class
F
,
class
Action
>
auto
action_decorate
(
F
f
,
Action
action
)
{
...
...
@@ -111,6 +118,8 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
std
::
cout
<<
"*********** Warning: No CK tuning!"
<<
std
::
endl
;
std
::
cout
<<
inputs
[
0
]
<<
std
::
endl
<<
inputs
[
1
]
<<
std
::
endl
;
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
...
...
@@ -152,12 +161,26 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
sc
=
c_shape
.
strides
().
front
();
auto
i
=
v
.
get
(
"tuning_val"
,
get_tuning_for
(
inputs
));
const
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
2
]
and
get_type
(
a_shape
)
==
x
[
3
]
and
get_type
(
b_shape
)
==
x
[
4
]
and
get_type
(
c_shape
)
==
x
[
5
];
});
const
bool
pad_m
=
m
%
8
;
const
bool
pad_n
=
n
%
8
;
const
bool
pad_k
=
k
%
8
;
if
(
pad_m
or
pad_n
or
pad_k
)
{
std
::
string
padding_t
=
"ck::tensor_operation::device::GemmSpecialization::"
;
padding_t
+=
pad_m
?
"M"
:
""
;
padding_t
+=
pad_n
?
"N"
:
""
;
padding_t
+=
pad_k
?
"K"
:
""
;
padding_t
+=
"Padding"
;
set_padding
(
instance
,
padding_t
);
}
hip_compile_options
options
;
auto
block_size
=
get_block_size
(
instance
);
auto
grid_size
=
get_grid_size
(
instance
,
m
,
n
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment