Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0e889dc7
Commit
0e889dc7
authored
Oct 18, 2022
by
Alan Turner
Browse files
Formatting
parent
8752d11f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
11 deletions
+7
-11
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-3
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+4
-8
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
0e889dc7
...
@@ -59,15 +59,15 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -59,15 +59,15 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
auto
k
=
a
.
lens
()[
1
];
auto
k
=
a
.
lens
()[
1
];
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
return
false
;
return
false
;
if
(
a
.
lens
()[
1
]
>=
2048
)
if
(
a
.
lens
()[
1
]
>=
2048
)
return
false
;
return
false
;
return
true
;
return
true
;
// std::cout << a << std::endl;
// std::cout << a << std::endl;
// std::cout << b << std::endl;
// std::cout << b << std::endl;
// printf("m, n, k: %zu, %zu, %zu\n", m, n, k);
// printf("m, n, k: %zu, %zu, %zu\n", m, n, k);
// if ((m == 1414 and n == 2048 and k == 512) or
// if ((m == 1414 and n == 2048 and k == 512) or
// (m == 4096 and n == 2048 and k == 1414) or
// (m == 4096 and n == 2048 and k == 1414) or
// (m == 2048 and n == 2048 and k == 512) or
// (m == 2048 and n == 2048 and k == 512) or
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
0e889dc7
...
@@ -91,10 +91,7 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
...
@@ -91,10 +91,7 @@ static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t
return
int_div_ceil
(
m
,
mpb
)
*
int_div_ceil
(
n
,
npb
);
return
int_div_ceil
(
m
,
mpb
)
*
int_div_ceil
(
n
,
npb
);
}
}
static
void
set_padding
(
std
::
vector
<
std
::
string
>&
s
,
const
std
::
string
p
)
static
void
set_padding
(
std
::
vector
<
std
::
string
>&
s
,
const
std
::
string
p
)
{
s
[
padding_index
]
=
p
;
}
{
s
[
padding_index
]
=
p
;
}
template
<
class
F
,
class
Action
>
template
<
class
F
,
class
Action
>
auto
action_decorate
(
F
f
,
Action
action
)
auto
action_decorate
(
F
f
,
Action
action
)
...
@@ -118,8 +115,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
...
@@ -118,8 +115,7 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
if
(
tuning
.
empty
())
std
::
cout
<<
"*********** Warning: No CK tuning!"
<<
std
::
endl
;
std
::
cout
<<
"*********** Warning: No CK tuning!"
<<
std
::
endl
;
std
::
cout
<<
inputs
[
0
]
<<
std
::
endl
std
::
cout
<<
inputs
[
0
]
<<
std
::
endl
<<
inputs
[
1
]
<<
std
::
endl
;
<<
inputs
[
1
]
<<
std
::
endl
;
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
if
(
it
==
tuning
.
end
())
...
@@ -160,7 +156,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -160,7 +156,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
sb
=
b_shape
.
strides
().
front
();
auto
sb
=
b_shape
.
strides
().
front
();
auto
sc
=
c_shape
.
strides
().
front
();
auto
sc
=
c_shape
.
strides
().
front
();
auto
i
=
v
.
get
(
"tuning_val"
,
get_tuning_for
(
inputs
));
auto
i
=
v
.
get
(
"tuning_val"
,
get_tuning_for
(
inputs
));
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
2
]
and
get_type
(
a_shape
)
==
x
[
3
]
and
get_layout
(
c_shape
)
==
x
[
2
]
and
get_type
(
a_shape
)
==
x
[
3
]
and
...
@@ -171,7 +167,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -171,7 +167,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const
bool
pad_n
=
n
%
8
;
const
bool
pad_n
=
n
%
8
;
const
bool
pad_k
=
k
%
8
;
const
bool
pad_k
=
k
%
8
;
if
(
pad_m
or
pad_n
or
pad_k
)
if
(
pad_m
or
pad_n
or
pad_k
)
{
{
std
::
string
padding_t
=
"ck::tensor_operation::device::GemmSpecialization::"
;
std
::
string
padding_t
=
"ck::tensor_operation::device::GemmSpecialization::"
;
padding_t
+=
pad_m
?
"M"
:
""
;
padding_t
+=
pad_m
?
"M"
:
""
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment