Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3ec069ec
Commit
3ec069ec
authored
Mar 14, 2023
by
Alan Turner
Browse files
Formatting
parent
6a825932
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
20 deletions
+32
-20
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+32
-20
No files found.
src/targets/gpu/jit/ck_gemm.cpp
View file @
3ec069ec
...
...
@@ -38,7 +38,6 @@
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
...
...
@@ -64,7 +63,6 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING_VALUE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
...
...
@@ -291,9 +289,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation
compile_op
(
context
&
/* ctx */
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
.
back
();
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
.
back
();
auto
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
rank
=
a_shape
.
lens
().
size
();
...
...
@@ -307,15 +305,17 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
auto
k
=
a_shape
.
lens
().
back
();
const
auto
numDTensors
=
inputs
.
size
()
-
3
;
const
bool
transA
=
transposed_matrix
(
a_shape
);
const
bool
transB
=
transposed_matrix
(
b_shape
);
const
bool
transCDE
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
cde_type
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
//get_type(c_shape);
const
bool
transA
=
transposed_matrix
(
a_shape
);
const
bool
transB
=
transposed_matrix
(
b_shape
);
const
bool
transCDE
=
transposed_matrix
(
c_shape
);
const
auto
a_type
=
get_type
(
a_shape
);
const
auto
b_type
=
get_type
(
b_shape
);
const
auto
cde_type
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_type
);
// get_type(c_shape);
const
auto
cde_layout
=
ck_tuple
(
inputs
.
begin
()
+
2
,
inputs
.
end
()
-
1
,
&
get_layout
);
std
::
string
ck_passthrough
=
"ck_passthrough"
;
//"ck::tensor_operation::element_wise::PassThrough";
std
::
string
ck_passthrough
=
"ck_passthrough"
;
//"ck::tensor_operation::element_wise::PassThrough";
std
::
string
cde_op
=
ck_passthrough
;
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
...
...
@@ -323,16 +323,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
auto
problem
=
ck
::
tensor_operation
::
device
::
instance
::
Problem
{
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
numDTensors
),
static_cast
<
ck
::
index_t
>
(
tuning_value
),
transA
,
transB
,
transCDE
,
a_type
,
b_type
,
cde_type
,
ck_passthrough
,
ck_passthrough
,
cde_op
,
cde_layout
};
const
auto
solution
=
problem
.
GetSolution
();
auto
problem
=
ck
::
tensor_operation
::
device
::
instance
::
Problem
{
static_cast
<
ck
::
index_t
>
(
m
),
static_cast
<
ck
::
index_t
>
(
n
),
static_cast
<
ck
::
index_t
>
(
k
),
static_cast
<
ck
::
index_t
>
(
numDTensors
),
static_cast
<
ck
::
index_t
>
(
tuning_value
),
transA
,
transB
,
transCDE
,
a_type
,
b_type
,
cde_type
,
ck_passthrough
,
ck_passthrough
,
cde_op
,
cde_layout
};
const
auto
solution
=
problem
.
GetSolution
();
auto
blocks_per_batch
=
problem
.
GetGridSize
();
auto
block_size
=
problem
.
GetBlockSize
();
auto
block_size
=
problem
.
GetBlockSize
();
hip_compile_options
options
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
inputs
=
inputs
;
options
.
output
=
c_shape
;
...
...
@@ -349,7 +361,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
if
(
v
.
get
(
"check"
,
false
)
or
enabled
(
MIGRAPHX_CK_DEBUG
{}))
options
.
params
+=
" -DMIGRAPHX_CK_CHECK=1"
;
auto
src
=
interpolate_string
(
ck_gemm_kernel
,
{{
"solution"
,
solution
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment