Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3ec069ec
Commit
3ec069ec
authored
Mar 14, 2023
by
Alan Turner
Browse files
Formatting
parent
6a825932
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
20 deletions
+32
-20
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+32
-20
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
3ec069ec
...
...
@@ -38,7 +38,6 @@
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
...
...
@@ -64,7 +63,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING_VALUE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
...
...
@@ -312,10 +310,12 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
const
bool
transCDE
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
cde_type
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
//get_type(c_shape);
const
auto
cde_type
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
// get_type(c_shape);
const
auto
cde_layout
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_layout
);
std
::
string
ck_passthrough
=
"ck_passthrough"
;
//"ck::tensor_operation::element_wise::PassThrough";
std
::
string
ck_passthrough
=
"ck_passthrough"
;
//"ck::tensor_operation::element_wise::PassThrough";
std
::
string
cde_op
=
ck_passthrough
;
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
...
...
@@ -323,10 +323,22 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
auto
problem
=
ck
::
tensor_operation
::
device
::
instance
::
Problem
{
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
numDTensors
),
static_cast
<
ck
::
index_t
>
(
tuning_value
),
transA
,
transB
,
transCDE
,
a_type
,
b_type
,
cde_type
,
ck_passthrough
,
ck_passthrough
,
cde_op
,
cde_layout
};
auto
problem
=
ck
::
tensor_operation
::
device
::
instance
::
Problem
{
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
numDTensors
),
static_cast
<
ck
::
index_t
>
(
tuning_value
),
transA
,
transB
,
transCDE
,
a_type
,
b_type
,
cde_type
,
ck_passthrough
,
ck_passthrough
,
cde_op
,
cde_layout
};
const
auto
solution
=
problem
.
GetSolution
();
auto
blocks_per_batch
=
problem
.
GetGridSize
();
auto
block_size
=
problem
.
GetBlockSize
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment