Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
90db83e7
Commit
90db83e7
authored
Oct 31, 2023
by
Paul
Browse files
Format
parent
8e20f747
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
21 deletions
+50
-21
host/include/ck/host/device_gemm_multiple_d.hpp
host/include/ck/host/device_gemm_multiple_d.hpp
+5
-1
host/include/ck/host/device_gemm_multiple_d/problem.hpp
host/include/ck/host/device_gemm_multiple_d/problem.hpp
+4
-0
host/include/ck/host/utils.hpp
host/include/ck/host/utils.hpp
+3
-0
host/src/utils.cpp
host/src/utils.cpp
+6
-0
host/test/gemm_multiple_d.cpp
host/test/gemm_multiple_d.cpp
+32
-20
No files found.
host/include/ck/host/device_gemm_multiple_d.hpp
View file @
90db83e7
...
...
@@ -9,7 +9,7 @@
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/
common
.hpp"
#include "ck/host/
types
.hpp"
namespace
ck
{
namespace
host
{
...
...
@@ -31,6 +31,10 @@ struct Problem
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
};
}
// namespace device_gemm_multiple_d
...
...
host/include/ck/host/device_gemm_multiple_d/problem.hpp
View file @
90db83e7
...
...
@@ -28,6 +28,10 @@ struct Problem
std
::
string
AElementOp
=
PassThrough
;
std
::
string
BElementOp
=
PassThrough
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
};
}
// namespace device_gemm_multiple_d
...
...
host/include/ck/host/utils.hpp
View file @
90db83e7
...
...
@@ -4,11 +4,14 @@
#pragma once
#include <cstdint>
#include <unordered_set>
namespace
ck
{
namespace
host
{
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
();
}
// namespace host
}
// namespace ck
host/src/utils.cpp
View file @
90db83e7
...
...
@@ -11,5 +11,11 @@ std::size_t integer_divide_ceil(std::size_t x, std::size_t y)
return
(
x
+
y
-
std
::
size_t
{
1
})
/
y
;
}
const
std
::
unordered_set
<
std
::
string
>&
get_xdlop_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx90a"
,
"gfx908"
,
"gfx940"
,
"gfx942"
};
return
supported_archs
;
}
}
// namespace host
}
// namespace ck
host/test/gemm_multiple_d.cpp
View file @
90db83e7
...
...
@@ -7,15 +7,6 @@
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
const
std
::
string
compile_check
=
R"__ck__(
#include <${include}>
extern "C" __global__ void f() {
using type = ${template}::DeviceOp;
}
)__ck__"
;
std
::
vector
<
rtc
::
src_file
>
get_headers_for_test
()
{
std
::
vector
<
rtc
::
src_file
>
result
;
...
...
@@ -29,20 +20,41 @@ std::vector<rtc::src_file> get_headers_for_test()
return
result
;
}
TEST_CASE
(
test_operation
)
const
std
::
string
gemm_compile_check
=
R"__ck__(
#include <${include}>
extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* c) {
using G = ${template};
constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${n, ${k})),
ck::make_tuple(),
ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m, ${n})));
static_assert(desc.IsValid(), "Invalid ck gemm.");
${template}::Run(desc,
a,
b,
ck::make_tuple(),
c);
}
)__ck__"
;
TEST_CASE
(
test_problem_kernel
)
{
ck
::
host
::
device_gemm_multiple_d
::
Problem
prob
;
prob
.
M
=
256
;
prob
.
N
=
256
;
prob
.
K
=
256
;
auto
ops
=
ck
::
host
::
device_gemm_multiple_d
::
Operation_Xdl_CShuffle
::
CreateOperations
(
prob
);
for
(
auto
op
:
ops
)
for
(
auto
solution
:
prob
.
GetSolutions
(
"gfx90a"
))
{
auto
solution
=
op
.
ToSolution
();
std
::
string
include
=
"ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
;
auto
src
=
ck
::
host
::
InterpolateString
(
compile_check
,
{{
"include"
,
include
},
{
"template"
,
solution
.
ToTemplateString
()}});
auto
src
=
ck
::
host
::
InterpolateString
(
compile_check
,
{{
"include"
,
prob
.
GetIncludeHeader
()},
{
"template"
,
solution
.
ToTemplateString
()},
{
"m"
,
std
::
to_string
(
prob
.
M
)},
{
"n"
,
std
::
to_string
(
prob
.
N
)},
{
"k"
,
std
::
to_string
(
prob
.
K
)}});
auto
srcs
=
get_headers_for_test
();
srcs
.
push_back
({
"main.cpp"
,
src
});
rtc
::
compile_options
options
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment