Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
74f21ca6
Commit
74f21ca6
authored
May 26, 2023
by
Paul
Browse files
Refactor
parent
674b3bac
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
12 deletions
+36
-12
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+5
-3
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+31
-9
No files found.
src/targets/gpu/compile_ops.cpp
View file @
74f21ca6
...
@@ -112,11 +112,13 @@ struct compile_plan
...
@@ -112,11 +112,13 @@ struct compile_plan
{
{
if
(
results
.
size
()
==
1
)
if
(
results
.
size
()
==
1
)
return
results
.
front
();
return
results
.
front
();
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
std
::
cout
<<
"Benchmarking "
<<
preop
.
name
()
<<
": "
<<
results
.
size
()
<<
" configs"
<<
std
::
endl
;
std
::
vector
<
double
>
times
;
std
::
vector
<
double
>
times
;
for
(
const
auto
&
cr
:
results
)
for
(
const
auto
&
cr
:
results
)
{
{
times
.
push_back
(
time_op
(
*
ctx
,
cr
.
replace
.
code_object
,
to_shapes
(
cr
.
ins
->
inputs
()),
20
).
first
);
times
.
push_back
(
time_op
(
*
ctx
,
cr
.
replace
.
code_object
,
to_shapes
(
cr
.
ins
->
inputs
()),
20
).
first
);
}
}
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
auto
i
=
std
::
distance
(
times
.
begin
(),
std
::
min_element
(
times
.
begin
(),
times
.
end
()));
return
results
[
i
];
return
results
[
i
];
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
74f21ca6
...
@@ -267,20 +267,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -267,20 +267,28 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"ck_gemm"
,
"gpu::ck_gemm"
};
}
operation
compile_op
(
context
&
/* ctx */
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
bool
can_fold_batch
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
rank
=
a_shape
.
lens
().
size
();
auto
b_strides
=
b_shape
.
strides
();
return
rank
>=
3
and
b_strides
[
rank
-
3
]
==
0
;
}
ck
::
host
::
device_gemm_multiple_d
::
Problem
create_problem
(
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
auto
a_shape
=
inputs
[
0
];
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
.
back
();
auto
c_shape
=
inputs
.
back
();
auto
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
rank
=
a_shape
.
lens
().
size
();
auto
rank
=
a_shape
.
lens
().
size
();
auto
b_strides
=
b_shape
.
strides
();
auto
b_strides
=
b_shape
.
strides
();
bool
can_fold_batch
=
rank
>=
3
and
b_strides
[
rank
-
3
]
==
0
;
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
auto
m
=
c_shape
.
lens
()[
rank
-
2
];
m
=
can_fold_batch
?
m
*
batch_count
:
m
;
m
=
can_fold_batch
(
inputs
)
?
m
*
batch_count
:
m
;
auto
n
=
c_shape
.
lens
().
back
();
auto
n
=
c_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
auto
k
=
a_shape
.
lens
().
back
();
...
@@ -309,7 +317,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -309,7 +317,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
cde_op
=
v
.
at
(
"post"
).
to
<
std
::
string
>
();
}
}
auto
problem
=
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
m
,
return
ck
::
host
::
device_gemm_multiple_d
::
Problem
{
m
,
n
,
n
,
k
,
k
,
transA
,
transA
,
...
@@ -323,6 +331,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -323,6 +331,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
ck_passthrough
,
cde_op
};
cde_op
};
}
operation
compile_op
(
context
&
/* ctx */
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
c_shape
=
inputs
.
back
();
auto
tuning_value
=
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
});
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
problem
=
create_problem
(
inputs
,
v
);
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
include_header
=
problem
.
GetIncludeHeader
();
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
const
auto
solutions
=
problem
.
GetSolutions
(
"gfx90a"
);
...
@@ -333,13 +351,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -333,13 +351,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
hip_compile_options
options
;
hip_compile_options
options
;
options
.
additional_src_files
=
ck_headers
();
options
.
additional_src_files
=
ck_headers
();
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
auto
grid_size
=
can_fold_batch
(
inputs
)
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
c_shape
;
options
.
output
=
c_shape
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
virtual_inputs
=
inputs
;
options
.
virtual_inputs
=
inputs
;
if
(
can_fold_batch
)
if
(
can_fold_batch
(
inputs
)
)
{
{
auto
vinputs
=
inputs
;
auto
vinputs
=
inputs
;
fold_batch_dims
(
vinputs
[
0
]);
fold_batch_dims
(
vinputs
[
0
]);
...
@@ -363,7 +381,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -363,7 +381,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
value
create_settings
(
instruction_ref
ins
,
const
operation
&
op
)
const
{
{
auto
v
=
op
.
to_value
();
auto
v
=
op
.
to_value
();
v
[
"kernel"
]
=
"ck_gemm_kernel"
;
v
[
"kernel"
]
=
"ck_gemm_kernel"
;
...
@@ -375,9 +393,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -375,9 +393,13 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm>"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
v
[
"kernel"
]
=
"ck_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
}
return
v
;
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
auto
shapes
=
to_shapes
(
ins
->
inputs
());
auto
shapes
=
to_shapes
(
ins
->
inputs
());
return
{
compile_op
(
ctx
,
shapes
,
v
),
return
{
compile_op
(
ctx
,
shapes
,
create_settings
(
ins
,
op
)
),
[
=
](
module
&
m
,
instruction_ref
ins2
,
const
operation
&
code_object
)
{
[
=
](
module
&
m
,
instruction_ref
ins2
,
const
operation
&
code_object
)
{
if
(
enabled
(
MIGRAPHX_LOG_CK_GEMM
{}))
if
(
enabled
(
MIGRAPHX_LOG_CK_GEMM
{}))
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment