Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
08255e1b
Commit
08255e1b
authored
Sep 18, 2024
by
Dino Musić
Browse files
Implement hiprtc for codegen tests
parent
1658c0dc
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
440 additions
and
146 deletions
+440
-146
codegen/CMakeLists.txt
codegen/CMakeLists.txt
+21
-12
codegen/include/ck/host/stringutils.hpp
codegen/include/ck/host/stringutils.hpp
+28
-0
codegen/test/common.hpp
codegen/test/common.hpp
+68
-7
codegen/test/gemm_multiple_d.cpp
codegen/test/gemm_multiple_d.cpp
+21
-125
codegen/test/rtc/include/rtc/compile_kernel.hpp
codegen/test/rtc/include/rtc/compile_kernel.hpp
+6
-1
codegen/test/rtc/include/rtc/hip.hpp
codegen/test/rtc/include/rtc/hip.hpp
+1
-0
codegen/test/rtc/include/rtc/hiprtc_enable_env.hpp
codegen/test/rtc/include/rtc/hiprtc_enable_env.hpp
+3
-0
codegen/test/rtc/src/compile_kernel.cpp
codegen/test/rtc/src/compile_kernel.cpp
+292
-1
No files found.
codegen/CMakeLists.txt
View file @
08255e1b
cmake_minimum_required
(
VERSION 3.16
)
project
(
composable_kernel_host LANGUAGES CXX HIP
)
find_package
(
ROCM
)
include
(
ROCMInstallTargets
)
include
(
ROCMTest
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_EXPORT_COMPILE_COMMANDS ON
)
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/lib
)
set
(
CMAKE_LIBRARY_OUTPUT_DIRECTORY
${
CMAKE_BINARY_DIR
}
/lib
)
...
@@ -15,6 +21,7 @@ include_directories(BEFORE
...
@@ -15,6 +21,7 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
PROJECT_SOURCE_DIR
}
/library/include
${
HIP_INCLUDE_DIRS
}
${
HIP_INCLUDE_DIRS
}
${
CK_ROOT
}
/include/
)
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
list
(
APPEND CMAKE_MODULE_PATH
${
CK_ROOT
}
/cmake
)
...
@@ -39,6 +46,8 @@ set_target_properties(ck_host PROPERTIES
...
@@ -39,6 +46,8 @@ set_target_properties(ck_host PROPERTIES
target_include_directories
(
ck_host PUBLIC
target_include_directories
(
ck_host PUBLIC
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/solution_instances>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/ck_headers/include>
)
)
add_executable
(
ck-template-driver driver/main.cpp
)
add_executable
(
ck-template-driver driver/main.cpp
)
...
...
codegen/include/ck/host/stringutils.hpp
View file @
08255e1b
...
@@ -100,5 +100,33 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
...
@@ -100,5 +100,33 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
return
result
;
return
result
;
}
}
inline
bool
StartsWith
(
const
std
::
string
&
value
,
const
std
::
string
&
prefix
)
{
if
(
prefix
.
size
()
>
value
.
size
())
return
false
;
else
return
std
::
equal
(
prefix
.
begin
(),
prefix
.
end
(),
value
.
begin
());
}
inline
bool
EndsWith
(
const
std
::
string
&
value
,
const
std
::
string
&
suffix
)
{
if
(
suffix
.
size
()
>
value
.
size
())
return
false
;
else
return
std
::
equal
(
suffix
.
rbegin
(),
suffix
.
rend
(),
value
.
rbegin
());
}
inline
std
::
vector
<
std
::
string
>
SplitString
(
const
std
::
string
&
s
,
char
delim
)
{
std
::
vector
<
std
::
string
>
elems
;
std
::
stringstream
ss
(
s
+
delim
);
std
::
string
item
;
while
(
std
::
getline
(
ss
,
item
,
delim
))
{
elems
.
push_back
(
item
);
}
return
elems
;
}
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/test/common.hpp
View file @
08255e1b
...
@@ -8,18 +8,73 @@
...
@@ -8,18 +8,73 @@
#include <rtc/compile_kernel.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <fstream>
#include <unordered_set>
#include "ck/host/headers.hpp"
#include "rtc/hiprtc_enable_env.hpp"
#include "ck/host/stringutils.hpp"
std
::
vector
<
rtc
::
src_file
>
get_headers_for_test
()
// NOLINTNEXTLINE
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
#pragma clang diagnostic pop
)__migraphx__"
;
template
<
class
P
>
inline
std
::
string
ck_disable_warnings
(
P
p
)
{
return
ck
::
host
::
InterpolateString
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
data
(),
p
.
size
()}}});
}
inline
std
::
vector
<
rtc
::
src_file
>
create_headers_for_hiprtc_test
()
{
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
vector
<
rtc
::
src_file
>
result
;
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
&
p
)
{
return
rtc
::
src_file
{
p
.
first
,
ck_disable_warnings
(
p
.
second
)};
});
return
result
;
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_hiprtc_test
()
{
static
const
std
::
vector
<
rtc
::
src_file
>
headers
=
create_headers_for_hiprtc_test
();
return
headers
;
}
inline
std
::
vector
<
rtc
::
src_file
>
create_headers_for_clang_test
()
{
{
std
::
vector
<
rtc
::
src_file
>
result
;
std
::
vector
<
rtc
::
src_file
>
result
;
auto
hs
=
ck
::
host
::
GetHeaders
();
auto
hs
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
std
::
transform
(
hs
.
begin
(),
hs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
p
)
->
rtc
::
src_file
{
hs
.
begin
(),
hs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
p
)
->
rtc
::
src_file
{
return
{
p
.
first
,
p
.
second
};
return
{
p
.
first
,
{
p
.
second
.
begin
(),
p
.
second
.
end
()}
};
});
});
return
result
;
return
result
;
}
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_clang_test
()
{
static
const
std
::
vector
<
rtc
::
src_file
>
headers
=
create_headers_for_clang_test
();
return
headers
;
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_test
()
{
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)))
{
return
get_headers_for_hiprtc_test
();
}
else
{
return
get_headers_for_clang_test
();
}
}
template
<
typename
V
>
template
<
typename
V
>
std
::
size_t
GetSize
(
V
mLens
,
V
mStrides
)
std
::
size_t
GetSize
(
V
mLens
,
V
mStrides
)
{
{
...
@@ -34,18 +89,24 @@ std::size_t GetSize(V mLens, V mStrides)
...
@@ -34,18 +89,24 @@ std::size_t GetSize(V mLens, V mStrides)
return
space
;
return
space
;
}
}
template
<
class
T
,
typename
V
>
template
<
class
T
>
rtc
::
buffer
<
T
>
generate_buffer
(
V
mLens
,
V
mStrides
,
std
::
size_t
seed
=
0
)
rtc
::
buffer
<
T
>
generate_buffer
(
std
::
size_t
n
,
std
::
size_t
seed
=
0
)
{
{
std
::
size_t
space
=
GetSize
(
mLens
,
mStrides
);
rtc
::
buffer
<
T
>
result
(
n
);
rtc
::
buffer
<
T
>
result
(
space
);
std
::
mt19937
gen
(
seed
);
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<
double
>
dis
(
-
1.0
);
std
::
uniform_real_distribution
<
double
>
dis
(
-
1.0
);
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
// std::fill(result.begin(), result.end(), 1);
return
result
;
return
result
;
}
}
template
<
class
T
,
typename
V
>
std
::
enable_if_t
<!
std
::
is_integral_v
<
V
>
,
rtc
::
buffer
<
T
>>
generate_buffer
(
V
mLens
,
V
mStrides
,
std
::
size_t
seed
=
0
)
{
std
::
size_t
space
=
GetSize
(
mLens
,
mStrides
);
return
generate_buffer
<
T
>
(
space
,
seed
);
}
template
<
class
T
,
class
U
>
template
<
class
T
,
class
U
>
bool
allclose
(
const
T
&
a
,
const
U
&
b
,
double
atol
=
0.01
,
double
rtol
=
0.01
)
bool
allclose
(
const
T
&
a
,
const
U
&
b
,
double
atol
=
0.01
,
double
rtol
=
0.01
)
{
{
...
...
codegen/test/gemm_multiple_d.cpp
View file @
08255e1b
#include "common.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/headers.hpp"
...
@@ -15,116 +16,6 @@
...
@@ -15,116 +16,6 @@
using
half
=
_Float16
;
using
half
=
_Float16
;
// using half = __fp16;
// using half = __fp16;
std
::
vector
<
rtc
::
src_file
>
get_headers_for_test
()
{
std
::
vector
<
rtc
::
src_file
>
result
;
auto
hs
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
hs
.
begin
(),
hs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
p
)
->
rtc
::
src_file
{
return
{
p
.
first
,
p
.
second
};
});
return
result
;
}
template
<
class
T
>
rtc
::
buffer
<
T
>
generate_buffer
(
std
::
size_t
n
,
std
::
size_t
seed
=
0
)
{
rtc
::
buffer
<
T
>
result
(
n
);
std
::
mt19937
gen
(
seed
);
std
::
uniform_real_distribution
<
double
>
dis
(
-
1.0
);
std
::
generate
(
result
.
begin
(),
result
.
end
(),
[
&
]
{
return
dis
(
gen
);
});
return
result
;
}
template
<
class
T
,
class
U
>
bool
allclose
(
const
T
&
a
,
const
U
&
b
,
double
atol
=
0.01
,
double
rtol
=
0.01
)
{
return
std
::
equal
(
a
.
begin
(),
a
.
end
(),
b
.
begin
(),
b
.
end
(),
[
&
](
double
x
,
double
y
)
{
return
fabs
(
x
-
y
)
<
atol
+
rtol
*
fabs
(
y
);
});
}
std
::
string
classify
(
double
x
)
{
switch
(
std
::
fpclassify
(
x
))
{
case
FP_INFINITE
:
return
"inf"
;
case
FP_NAN
:
return
"nan"
;
case
FP_NORMAL
:
return
"normal"
;
case
FP_SUBNORMAL
:
return
"subnormal"
;
case
FP_ZERO
:
return
"zero"
;
default:
return
"unknown"
;
}
}
template
<
class
Buffer
>
void
print_classification
(
const
Buffer
&
x
)
{
std
::
unordered_set
<
std
::
string
>
result
;
for
(
const
auto
&
i
:
x
)
result
.
insert
(
classify
(
i
));
for
(
const
auto
&
c
:
result
)
std
::
cout
<<
c
<<
", "
;
std
::
cout
<<
std
::
endl
;
}
template
<
class
Buffer
>
void
print_statistics
(
const
Buffer
&
x
)
{
std
::
cout
<<
"Min value: "
<<
*
std
::
min_element
(
x
.
begin
(),
x
.
end
())
<<
", "
;
std
::
cout
<<
"Max value: "
<<
*
std
::
max_element
(
x
.
begin
(),
x
.
end
())
<<
", "
;
double
num_elements
=
x
.
size
();
auto
mean
=
std
::
accumulate
(
x
.
begin
(),
x
.
end
(),
double
{
0.0
},
std
::
plus
<
double
>
{})
/
num_elements
;
auto
stddev
=
std
::
sqrt
(
std
::
accumulate
(
x
.
begin
(),
x
.
end
(),
double
{
0.0
},
[
&
](
double
r
,
double
v
)
{
return
r
+
std
::
pow
((
v
-
mean
),
2.0
);
})
/
num_elements
);
std
::
cout
<<
"Mean: "
<<
mean
<<
", "
;
std
::
cout
<<
"StdDev: "
<<
stddev
<<
"
\n
"
;
}
template
<
class
Buffer
>
void
print_preview
(
const
Buffer
&
x
)
{
if
(
x
.
size
()
<=
10
)
{
std
::
for_each
(
x
.
begin
(),
x
.
end
(),
[
&
](
double
i
)
{
std
::
cout
<<
i
<<
", "
;
});
}
else
{
std
::
for_each
(
x
.
begin
(),
x
.
begin
()
+
5
,
[
&
](
double
i
)
{
std
::
cout
<<
i
<<
", "
;
});
std
::
cout
<<
"..., "
;
std
::
for_each
(
x
.
end
()
-
5
,
x
.
end
(),
[
&
](
double
i
)
{
std
::
cout
<<
i
<<
", "
;
});
}
std
::
cout
<<
std
::
endl
;
}
template
<
class
T
>
struct
check_all
{
rtc
::
buffer
<
T
>
data
{};
bool
operator
()(
const
rtc
::
buffer
<
T
>&
x
)
{
if
(
data
.
empty
())
{
data
=
x
;
return
true
;
}
if
(
std
::
any_of
(
x
.
begin
(),
x
.
end
(),
[](
double
y
)
{
return
std
::
isnan
(
y
);
}))
return
false
;
return
allclose
(
data
,
x
);
}
};
template
<
class
Solution
>
auto
report
(
const
Solution
&
solution
,
bool
pass
)
{
return
test
::
make_predicate
(
solution
.
ToTemplateString
(),
[
=
]
{
return
pass
;
});
}
const
std
::
string
gemm_compile_check
=
R"__ck__(
const
std
::
string
gemm_compile_check
=
R"__ck__(
#include <${include}>
#include <${include}>
...
@@ -163,19 +54,24 @@ TEST_CASE(test_problem_kernel)
...
@@ -163,19 +54,24 @@ TEST_CASE(test_problem_kernel)
std
::
string
epilogue
=
""
;
std
::
string
epilogue
=
""
;
std
::
string
prologue
=
""
;
std
::
string
prologue
=
""
;
for
(
auto
solution
:
prob
.
GetSolutions
(
"gfx90a"
,
prologue
,
epilogue
))
auto
solutions
=
prob
.
GetSolutions
(
"gfx90a"
,
prologue
,
epilogue
);
std
::
cout
<<
"Num solutions: "
<<
solutions
.
size
()
<<
std
::
endl
;
for
(
auto
i
=
0
;
i
<
solutions
.
size
();
++
i
)
{
{
std
::
cout
<<
"Testing solution "
<<
std
::
to_string
(
i
+
1
)
<<
std
::
endl
;
auto
&&
solution
=
solutions
[
i
];
auto
src
=
ck
::
host
::
InterpolateString
(
gemm_compile_check
,
auto
src
=
ck
::
host
::
InterpolateString
(
gemm_compile_check
,
{{
"include"
,
prob
.
GetIncludeHeader
()},
{{
"include"
,
prob
.
GetIncludeHeader
()},
{
"template"
,
solution
.
ToTemplateString
()},
{
"template"
,
solution
.
ToTemplateString
()},
{
"m"
,
std
::
to_string
(
prob
.
M
)},
{
"m"
,
std
::
to_string
(
prob
.
M
)},
{
"n"
,
std
::
to_string
(
prob
.
N
)},
{
"n"
,
std
::
to_string
(
prob
.
N
)},
{
"k"
,
std
::
to_string
(
prob
.
K
)}});
{
"k"
,
std
::
to_string
(
prob
.
K
)}});
auto
srcs
=
get_headers_for_test
();
srcs
.
push_back
({
"main.cpp"
,
src
});
rtc
::
compile_options
options
;
rtc
::
compile_options
options
;
options
.
kernel_name
=
"f"
;
options
.
kernel_name
=
"f"
;
auto
k
=
rtc
::
compile_kernel
(
srcs
,
options
);
options
.
additional_src_files
=
get_headers_for_test
();
auto
k
=
rtc
::
compile_kernel
(
src
,
options
);
auto
block_size
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"BlockSize"
);
auto
block_size
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"BlockSize"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"MPerBlock"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"NPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"NPerBlock"
);
...
...
codegen/test/rtc/include/rtc/compile_kernel.hpp
View file @
08255e1b
...
@@ -4,24 +4,29 @@
...
@@ -4,24 +4,29 @@
#include <rtc/kernel.hpp>
#include <rtc/kernel.hpp>
#include <ck/filesystem.hpp>
#include <ck/filesystem.hpp>
#include <string>
#include <string>
#include <functional>
namespace
rtc
{
namespace
rtc
{
struct
src_file
struct
src_file
{
{
CK
::
fs
::
path
path
;
CK
::
fs
::
path
path
;
std
::
string
_view
content
;
std
::
string
content
;
};
};
struct
compile_options
struct
compile_options
{
{
std
::
string
flags
=
""
;
std
::
string
flags
=
""
;
std
::
string
kernel_name
=
"main"
;
std
::
string
kernel_name
=
"main"
;
std
::
vector
<
src_file
>
additional_src_files
=
{};
std
::
string
params
=
""
;
};
};
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
src
,
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
src
,
compile_options
options
=
compile_options
{});
compile_options
options
=
compile_options
{});
kernel
compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
=
compile_options
{});
}
// namespace rtc
}
// namespace rtc
#endif
#endif
codegen/test/rtc/include/rtc/hip.hpp
View file @
08255e1b
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime_api.h>
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <stdexcept>
namespace
rtc
{
namespace
rtc
{
...
...
codegen/test/rtc/include/rtc/hiprtc_enable_env.hpp
0 → 100644
View file @
08255e1b
#include <ck/utility/env.hpp>
CK_DECLARE_ENV_VAR_BOOL
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)
\ No newline at end of file
codegen/test/rtc/src/compile_kernel.cpp
View file @
08255e1b
#include "rtc/hip.hpp"
#include "rtc/hip.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/compile_kernel.hpp>
// TODO include only if USE_RTC is set?
#include <hip/hiprtc.h>
#include <rtc/tmp_dir.hpp>
#include <rtc/tmp_dir.hpp>
#include <stdexcept>
#include <stdexcept>
#include <iostream>
#include <iostream>
#include <fstream>
#include <fstream>
#include <cassert>
#include <cassert>
#include <numeric>
#include <deque>
#include <rtc/hiprtc_enable_env.hpp>
#include <ck/host/stringutils.hpp>
namespace
rtc
{
namespace
rtc
{
...
@@ -59,7 +65,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device
...
@@ -59,7 +65,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device
// TODO: undo after extracting the codeobj
// TODO: undo after extracting the codeobj
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
srcs
,
compile_options
options
)
kernel
clang_
compile_kernel
(
const
std
::
vector
<
src_file
>&
srcs
,
compile_options
options
)
{
{
assert
(
not
srcs
.
empty
());
assert
(
not
srcs
.
empty
());
tmp_dir
td
{
"compile"
};
tmp_dir
td
{
"compile"
};
...
@@ -100,4 +106,289 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
...
@@ -100,4 +106,289 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
return
kernel
{
obj
.
data
(),
options
.
kernel_name
};
return
kernel
{
obj
.
data
(),
options
.
kernel_name
};
}
}
struct
hiprtc_src_file
{
hiprtc_src_file
()
=
default
;
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
)
{}
std
::
string
path
;
std
::
string
content
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
path
,
"path"
),
f
(
self
.
content
,
"content"
));
}
};
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
}
void
hiprtc_check_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
,
const
std
::
string
&
ctx
)
{
if
(
err
!=
HIPRTC_SUCCESS
)
throw
std
::
runtime_error
(
hiprtc_error
(
err
,
msg
));
}
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template
<
class
F
,
F
f
>
// NOLINT
struct
manage_deleter
{
template
<
class
T
>
void
operator
()(
T
*
x
)
const
{
if
(
x
!=
nullptr
)
{
(
void
)
f
(
x
);
}
}
};
template
<
class
T
,
class
F
,
F
f
>
// NOLINT
using
manage_ptr
=
std
::
unique_ptr
<
T
,
manage_deleter
<
F
,
f
>>
;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void
hiprtc_program_destroy
(
hiprtcProgram
prog
)
{
hiprtcDestroyProgram
(
&
prog
);
}
using
hiprtc_program_ptr
=
MIGRAPHX_MANAGE_PTR
(
hiprtcProgram
,
hiprtc_program_destroy
);
template
<
class
...
Ts
>
hiprtc_program_ptr
hiprtc_program_create
(
Ts
...
xs
)
{
hiprtcProgram
prog
=
nullptr
;
auto
result
=
hiprtcCreateProgram
(
&
prog
,
xs
...);
hiprtc_program_ptr
p
{
prog
};
if
(
result
!=
HIPRTC_SUCCESS
)
MIGRAPHX_HIPRTC_THROW
(
result
,
"Create program failed."
);
return
p
;
}
struct
hiprtc_program
{
struct
string_array
{
std
::
deque
<
std
::
string
>
strings
{};
std
::
vector
<
const
char
*>
c_strs
{};
string_array
()
{}
string_array
(
const
string_array
&
)
=
delete
;
std
::
size_t
size
()
const
{
return
strings
.
size
();
}
const
char
**
data
()
{
return
c_strs
.
data
();
}
void
push_back
(
std
::
string
s
)
{
strings
.
push_back
(
std
::
move
(
s
));
c_strs
.
push_back
(
strings
.
back
().
c_str
());
}
};
hiprtc_program_ptr
prog
=
nullptr
;
string_array
headers
{};
string_array
include_names
{};
std
::
string
cpp_src
=
""
;
std
::
string
cpp_name
=
""
;
hiprtc_program
(
const
std
::
string
&
src
,
const
std
::
string
&
name
=
"main.cpp"
)
:
cpp_src
(
src
),
cpp_name
(
name
)
{
create_program
();
}
hiprtc_program
(
std
::
vector
<
src_file
>
srcs
)
{
for
(
auto
&&
src
:
srcs
)
{
if
(
ck
::
host
::
EndsWith
(
src
.
path
,
".cpp"
))
{
cpp_src
=
std
::
move
(
src
.
content
);
cpp_name
=
std
::
move
(
src
.
path
);
}
else
{
headers
.
push_back
(
std
::
string
(
src
.
content
.
begin
(),
src
.
content
.
end
()));
include_names
.
push_back
(
std
::
move
(
src
.
path
));
}
}
create_program
();
}
void
create_program
()
{
assert
(
not
cpp_src
.
empty
());
assert
(
not
cpp_name
.
empty
());
assert
(
headers
.
size
()
==
include_names
.
size
());
prog
=
hiprtc_program_create
(
cpp_src
.
c_str
(),
cpp_name
.
c_str
(),
headers
.
size
(),
headers
.
data
(),
include_names
.
data
());
}
void
compile
(
const
std
::
vector
<
std
::
string
>&
options
,
bool
quiet
=
false
)
const
{
std
::
vector
<
const
char
*>
c_options
;
std
::
transform
(
options
.
begin
(),
options
.
end
(),
std
::
back_inserter
(
c_options
),
[](
const
std
::
string
&
s
)
{
return
s
.
c_str
();
});
auto
result
=
hiprtcCompileProgram
(
prog
.
get
(),
c_options
.
size
(),
c_options
.
data
());
auto
prog_log
=
log
();
if
(
not
prog_log
.
empty
()
and
not
quiet
)
{
std
::
cerr
<<
prog_log
<<
std
::
endl
;
}
if
(
result
!=
HIPRTC_SUCCESS
)
throw
std
::
runtime_error
(
"Compilation failed."
);
}
std
::
string
log
()
const
{
std
::
size_t
n
=
0
;
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
if
(
n
==
0
)
return
{};
std
::
string
buffer
(
n
,
'\0'
);
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
assert
(
buffer
.
back
()
!=
0
);
return
buffer
;
}
std
::
vector
<
char
>
get_code_obj
()
const
{
std
::
size_t
n
=
0
;
MIGRAPHX_HIPRTC
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
std
::
vector
<
char
>
buffer
(
n
);
MIGRAPHX_HIPRTC
(
hiprtcGetCode
(
prog
.
get
(),
buffer
.
data
()));
return
buffer
;
}
};
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
src_file
>
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
)
{
hiprtc_program
prog
(
std
::
move
(
srcs
));
auto
options
=
ck
::
host
::
SplitString
(
params
,
' '
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
if
(
true
)
{
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
if
(
true
)
options
.
push_back
(
"-DMIGRAPHX_DEBUG"
);
if
(
std
::
none_of
(
options
.
begin
(),
options
.
end
(),
[](
const
std
::
string
&
s
)
{
return
ck
::
host
::
StartsWith
(
s
,
"--std="
)
or
ck
::
host
::
StartsWith
(
s
,
"-std="
);
}))
options
.
push_back
(
"-std=c++17"
);
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
"-O3"
);
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"--offload-arch="
+
arch
);
prog
.
compile
(
options
);
return
{
prog
.
get_code_obj
()};
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
)
{
hiprtc_program
prog
{
" "
};
try
{
prog
.
compile
(
flags
,
true
);
return
true
;
}
catch
(...)
{
return
false
;
}
}
bool
hip_accept_non_uniform_wg
()
{
static
bool
non_uniform_wg
=
hip_has_flags
({
"-fno-offload-uniform-block"
});
return
non_uniform_wg
;
}
static
std
::
vector
<
std
::
string
>
get_compiler_warnings
()
{
std
::
vector
<
std
::
string
>
warnings
=
{
"-Weverything"
,
"-Wno-c++98-compat"
,
"-Wno-c++98-compat-pedantic"
,
"-Wno-conversion"
,
"-Wno-double-promotion"
,
"-Wno-exit-time-destructors"
,
"-Wno-extra-semi"
,
"-Wno-extra-semi-stmt"
,
"-Wno-float-conversion"
,
"-Wno-gnu-anonymous-struct"
,
"-Wno-gnu-zero-variadic-macro-arguments"
,
"-Wno-missing-prototypes"
,
"-Wno-nested-anon-types"
,
"-Wno-padded"
,
"-Wno-shorten-64-to-32"
,
"-Wno-sign-conversion"
,
"-Wno-sign-compare"
,
"-Wno-unused-command-line-argument"
,
"-Wno-weak-vtables"
,
"-Wno-c99-extensions"
,
};
if
(
hip_has_flags
({
"-Werror"
,
"-Wunsafe-buffer-usage"
}))
warnings
.
push_back
(
"-Wno-unsafe-buffer-usage"
);
return
warnings
;
}
const
std
::
vector
<
std
::
string
>&
compiler_warnings
()
{
static
std
::
vector
<
std
::
string
>
warnings
=
get_compiler_warnings
();
return
warnings
;
}
static
kernel
hiprtc_compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
)
{
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
srcs
.
push_back
(
src_file
{
std
::
string
(
"main.cpp"
),
content
});
options
.
params
+=
" "
+
ck
::
host
::
JoinStrings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -Werror"
;
auto
cos
=
compile_hip_src_with_hiprtc
(
srcs
,
options
.
params
,
get_device_name
());
if
(
cos
.
size
()
!=
1
)
std
::
runtime_error
(
"No code object"
);
auto
&
obj
=
cos
.
front
();
return
kernel
{
obj
.
data
(),
options
.
kernel_name
};
}
kernel
compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
)
{
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)))
{
return
hiprtc_compile_kernel
(
content
,
options
);
}
else
{
options
.
additional_src_files
.
push_back
({
"main.cpp"
,
content
});
return
clang_compile_kernel
(
options
.
additional_src_files
,
options
);
}
}
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
src
,
compile_options
options
)
{
return
clang_compile_kernel
(
src
,
options
);
}
}
// namespace rtc
}
// namespace rtc
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment