Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
687a3310
"driver/olCompiling/include/logger.hpp" did not exist on "d2315b0dfcd6f31cca4328819eaf60d77e952dd6"
Commit
687a3310
authored
Dec 05, 2022
by
Paul
Browse files
Format
parent
24f0cb5b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
7 deletions
+4
-7
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+4
-7
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
687a3310
...
@@ -170,14 +170,11 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
...
@@ -170,14 +170,11 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
struct
ck_gemm_compiler
:
compiler
<
ck_gemm_compiler
>
{
{
static
bool
transposed_matrix
(
const
shape
&
s
)
static
bool
transposed_matrix
(
const
shape
&
s
)
{
return
s
.
strides
().
back
()
!=
1
;
}
{
return
s
.
strides
().
back
()
!=
1
;
}
static
std
::
string
get_layout
(
const
shape
&
s
)
static
std
::
string
get_layout
(
const
shape
&
s
)
{
{
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
return
transposed_matrix
(
s
)
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
:
"ck::tensor_layout::gemm::RowMajor"
;
}
}
static
std
::
string
get_type
(
const
shape
&
s
)
static
std
::
string
get_type
(
const
shape
&
s
)
...
@@ -197,9 +194,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -197,9 +194,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
static
std
::
vector
<
shape
>
adjust_inputs
(
std
::
vector
<
shape
>
inputs
,
bool
&
swap_inputs
)
static
std
::
vector
<
shape
>
adjust_inputs
(
std
::
vector
<
shape
>
inputs
,
bool
&
swap_inputs
)
{
{
swap_inputs
=
false
;
swap_inputs
=
false
;
auto
c_shape
=
inputs
.
back
();
auto
c_shape
=
inputs
.
back
();
if
(
not
transposed_matrix
(
c_shape
))
if
(
not
transposed_matrix
(
c_shape
))
return
inputs
;
return
inputs
;
std
::
vector
<
int64_t
>
perm
(
c_shape
.
lens
().
size
());
std
::
vector
<
int64_t
>
perm
(
c_shape
.
lens
().
size
());
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
std
::
iota
(
perm
.
begin
(),
perm
.
end
(),
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment