Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
873f6c0c
Commit
873f6c0c
authored
Oct 08, 2022
by
Paul
Browse files
Format
parent
6ad2af4e
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8058 additions
and
26 deletions
+8058
-26
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+2
-2
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+15
-17
src/targets/gpu/jit/ck_gemm_instances.hpp
src/targets/gpu/jit/ck_gemm_instances.hpp
+8041
-7
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
873f6c0c
...
@@ -25,7 +25,7 @@ struct ck_gemm
...
@@ -25,7 +25,7 @@ struct ck_gemm
void
check_gemm_shape
(
const
shape
&
s
)
const
void
check_gemm_shape
(
const
shape
&
s
)
const
{
{
if
(
contains
(
s
.
lens
(),
1
))
if
(
contains
(
s
.
lens
(),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm"
);
}
}
...
@@ -54,7 +54,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
...
@@ -54,7 +54,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return
false
;
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
if
(
a
.
lens
().
size
()
>
2
or
b
.
lens
().
size
()
>
2
)
return
false
;
return
false
;
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
return
(
a
.
lens
()[
0
]
%
8
==
0
and
a
.
lens
()[
1
]
%
8
==
0
and
b
.
lens
()[
0
]
%
8
==
0
and
b
.
lens
()[
1
]
%
8
==
0
);
b
.
lens
()[
1
]
%
8
==
0
);
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
873f6c0c
...
@@ -90,8 +90,8 @@ static std::size_t get_block_size(const std::vector<std::string>& s)
...
@@ -90,8 +90,8 @@ static std::size_t get_block_size(const std::vector<std::string>& s)
static
std
::
size_t
get_grid_size
(
const
std
::
vector
<
std
::
string
>&
s
,
std
::
size_t
m
,
std
::
size_t
n
)
static
std
::
size_t
get_grid_size
(
const
std
::
vector
<
std
::
string
>&
s
,
std
::
size_t
m
,
std
::
size_t
n
)
{
{
auto
mpb
=
std
::
stoull
(
s
[
block_size_index
+
1
]);
auto
mpb
=
std
::
stoull
(
s
[
block_size_index
+
1
]);
auto
npb
=
std
::
stoull
(
s
[
block_size_index
+
2
]);
auto
npb
=
std
::
stoull
(
s
[
block_size_index
+
2
]);
return
int_div_ceil
(
m
,
mpb
)
*
int_div_ceil
(
n
,
npb
);
return
int_div_ceil
(
m
,
mpb
)
*
int_div_ceil
(
n
,
npb
);
}
}
...
@@ -99,12 +99,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -99,12 +99,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
{
static
std
::
string
get_layout
(
const
shape
&
s
)
static
std
::
string
get_layout
(
const
shape
&
s
)
{
{
return
s
.
transposed
()
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
return
s
.
transposed
()
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
}
static
std
::
string
get_type
(
const
shape
&
s
)
static
std
::
string
get_type
(
const
shape
&
s
)
{
{
if
(
s
.
type
()
==
shape
::
half_type
)
if
(
s
.
type
()
==
shape
::
half_type
)
return
"ck::half_t"
;
return
"ck::half_t"
;
return
shape
::
cpp_type
(
s
.
type
());
return
shape
::
cpp_type
(
s
.
type
());
}
}
...
@@ -117,21 +118,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -117,21 +118,18 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
b_shape
=
inputs
[
1
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
[
2
];
auto
c_shape
=
inputs
[
2
];
auto
m
=
c_shape
.
lens
().
front
();
auto
m
=
c_shape
.
lens
().
front
();
auto
n
=
c_shape
.
lens
().
back
();
auto
n
=
c_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
auto
sa
=
a_shape
.
strides
().
front
();
auto
sa
=
a_shape
.
strides
().
front
();
auto
sb
=
b_shape
.
strides
().
front
();
auto
sb
=
b_shape
.
strides
().
front
();
auto
sc
=
c_shape
.
strides
().
front
();
auto
sc
=
c_shape
.
strides
().
front
();
int
i
=
v
.
get
(
"tuning_val"
,
4
);
int
i
=
v
.
get
(
"tuning_val"
,
4
);
const
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
const
auto
&
instance
=
get_instance
(
i
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
2
]
and
get_type
(
a_shape
)
==
x
[
3
]
and
get_layout
(
c_shape
)
==
x
[
2
]
and
get_type
(
b_shape
)
==
x
[
4
]
and
get_type
(
c_shape
)
==
x
[
5
];
get_type
(
a_shape
)
==
x
[
3
]
and
get_type
(
b_shape
)
==
x
[
4
]
and
get_type
(
c_shape
)
==
x
[
5
];
});
});
hip_compile_options
options
;
hip_compile_options
options
;
...
...
src/targets/gpu/jit/ck_gemm_instances.hpp
View file @
873f6c0c
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment