Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f52c2a4d
Commit
f52c2a4d
authored
Oct 02, 2024
by
Mirza Halilcevic
Browse files
Address PR comments.
parent
e3d444c8
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
38 additions
and
198 deletions
+38
-198
codegen/test/common.hpp
codegen/test/common.hpp
+8
-40
codegen/test/gemm_multiple_d.cpp
codegen/test/gemm_multiple_d.cpp
+3
-3
codegen/test/rtc/include/rtc/compile_kernel.hpp
codegen/test/rtc/include/rtc/compile_kernel.hpp
+1
-30
codegen/test/rtc/src/compile_kernel.cpp
codegen/test/rtc/src/compile_kernel.cpp
+24
-125
include/ck/ck.hpp
include/ck/ck.hpp
+2
-0
No files found.
codegen/test/common.hpp
View file @
f52c2a4d
...
@@ -14,65 +14,33 @@
...
@@ -14,65 +14,33 @@
#include "ck/host/stringutils.hpp"
#include "ck/host/stringutils.hpp"
// NOLINTNEXTLINE
// NOLINTNEXTLINE
const
char
*
const
disable_warning_pragma
=
R"__migraphx__(
const
char
*
const
content_wrapper
=
R"__ck__(
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
${content}
${content}
#pragma clang diagnostic pop
)__ck__"
;
)__migraphx__"
;
template
<
class
P
>
template
<
class
P
>
inline
std
::
string
ck_
disable_warnings
(
P
p
)
inline
std
::
string
ck_
content_wrapper
(
P
p
)
{
{
return
ck
::
host
::
InterpolateString
(
disable_warning_pragma
,
return
ck
::
host
::
InterpolateString
(
content_wrapper
,
{{
"content"
,
std
::
string
{
p
.
data
(),
p
.
size
()}}});
{{
"content"
,
std
::
string
{
p
.
data
(),
p
.
size
()}}});
}
}
inline
std
::
vector
<
rtc
::
src_file
>
create_headers_for_
hiprtc_
test
()
inline
std
::
vector
<
rtc
::
src_file
>
create_headers_for_test
()
{
{
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
vector
<
rtc
::
src_file
>
result
;
std
::
vector
<
rtc
::
src_file
>
result
;
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
&
p
)
{
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
auto
&
p
)
{
return
rtc
::
src_file
{
p
.
first
,
ck_
disable_warnings
(
p
.
second
)};
return
rtc
::
src_file
{
p
.
first
,
ck_
content_wrapper
(
p
.
second
)};
});
});
return
result
;
return
result
;
}
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_hiprtc_test
()
{
static
const
std
::
vector
<
rtc
::
src_file
>
headers
=
create_headers_for_hiprtc_test
();
return
headers
;
}
inline
std
::
vector
<
rtc
::
src_file
>
create_headers_for_clang_test
()
{
std
::
vector
<
rtc
::
src_file
>
result
;
auto
hs
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
hs
.
begin
(),
hs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
auto
&
p
)
->
rtc
::
src_file
{
return
{
p
.
first
,
{
p
.
second
.
begin
(),
p
.
second
.
end
()}};
});
return
result
;
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_clang_test
()
{
static
const
std
::
vector
<
rtc
::
src_file
>
headers
=
create_headers_for_clang_test
();
return
headers
;
}
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_test
()
inline
const
std
::
vector
<
rtc
::
src_file
>&
get_headers_for_test
()
{
{
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)))
static
const
std
::
vector
<
rtc
::
src_file
>
headers
=
create_headers_for_test
();
{
return
headers
;
return
get_headers_for_hiprtc_test
();
}
else
{
return
get_headers_for_clang_test
();
}
}
}
template
<
typename
V
>
template
<
typename
V
>
...
...
codegen/test/gemm_multiple_d.cpp
View file @
f52c2a4d
...
@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel)
...
@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel)
{
"m"
,
std
::
to_string
(
prob
.
M
)},
{
"m"
,
std
::
to_string
(
prob
.
M
)},
{
"n"
,
std
::
to_string
(
prob
.
N
)},
{
"n"
,
std
::
to_string
(
prob
.
N
)},
{
"k"
,
std
::
to_string
(
prob
.
K
)}});
{
"k"
,
std
::
to_string
(
prob
.
K
)}});
auto
srcs
=
get_headers_for_test
();
srcs
.
push_back
({
"main.cpp"
,
src
});
rtc
::
compile_options
options
;
rtc
::
compile_options
options
;
options
.
kernel_name
=
"f"
;
options
.
kernel_name
=
"f"
;
options
.
additional_src_files
=
get_headers_for_test
();
auto
k
=
rtc
::
compile_kernel
(
srcs
,
options
);
auto
k
=
rtc
::
compile_kernel
(
src
,
options
);
auto
block_size
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"BlockSize"
);
auto
block_size
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"BlockSize"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"MPerBlock"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"NPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
std
::
size_t
>
(
"NPerBlock"
);
...
...
codegen/test/rtc/include/rtc/compile_kernel.hpp
View file @
f52c2a4d
...
@@ -19,40 +19,11 @@ struct compile_options
...
@@ -19,40 +19,11 @@ struct compile_options
{
{
std
::
string
flags
=
""
;
std
::
string
flags
=
""
;
std
::
string
kernel_name
=
"main"
;
std
::
string
kernel_name
=
"main"
;
std
::
vector
<
src_file
>
additional_src_files
=
{};
std
::
string
params
=
""
;
};
};
struct
hip_compile_options
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
srcs
,
{
std
::
size_t
global
;
std
::
size_t
local
;
std
::
string
kernel_name
=
"kernel"
;
std
::
string
params
=
""
;
std
::
vector
<
src_file
>
additional_src_files
=
{};
/**
* @brief Set the launch parameters but allow v to override the values
*
* @param v A value class which can have a "global" and/or "local" keys to override the default
* global and local
* @param compute_global A function used to compute the global based on the local
* @param default_local The defaul local to use if its missing from the v parameter
*/
void
set_launch_params
(
const
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>&
compute_global
,
std
::
size_t
default_local
=
1024
);
void
set_launch_params
(
std
::
size_t
default_global
,
std
::
size_t
default_local
=
1024
)
{
set_launch_params
([
=
](
auto
)
{
return
default_global
;
},
default_local
);
}
};
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
src
,
compile_options
options
=
compile_options
{});
compile_options
options
=
compile_options
{});
kernel
compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
=
compile_options
{});
}
// namespace rtc
}
// namespace rtc
#endif
#endif
codegen/test/rtc/src/compile_kernel.cpp
View file @
f52c2a4d
...
@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
...
@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
}
}
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_HIPRTC(...) \
#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define
MIGRAPHX
_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
#define
RTC
_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
template
<
class
F
,
F
f
>
// NOLINT
struct
hiprtc_program_destroy
struct
manage_deleter
{
{
template
<
class
T
>
void
operator
()(
hiprtcProgram
prog
)
const
{
hiprtcDestroyProgram
(
&
prog
);
}
void
operator
()(
T
*
x
)
const
{
if
(
x
!=
nullptr
)
{
(
void
)
f
(
x
);
}
}
};
};
template
<
class
T
,
class
F
,
F
f
>
// NOLINT
using
hiprtc_program_ptr
=
using
manage_ptr
=
std
::
unique_ptr
<
T
,
manage_deleter
<
F
,
f
>>
;
std
::
unique_ptr
<
std
::
remove_pointer_t
<
hiprtcProgram
>
,
hiprtc_program_destroy
>
;
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
// Workaround hiprtc's broken API
void
hiprtc_program_destroy
(
hiprtcProgram
prog
)
{
hiprtcDestroyProgram
(
&
prog
);
}
using
hiprtc_program_ptr
=
MIGRAPHX_MANAGE_PTR
(
hiprtcProgram
,
hiprtc_program_destroy
);
template
<
class
...
Ts
>
template
<
class
...
Ts
>
hiprtc_program_ptr
hiprtc_program_create
(
Ts
...
xs
)
hiprtc_program_ptr
hiprtc_program_create
(
Ts
...
xs
)
...
@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
...
@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
auto
result
=
hiprtcCreateProgram
(
&
prog
,
xs
...);
auto
result
=
hiprtcCreateProgram
(
&
prog
,
xs
...);
hiprtc_program_ptr
p
{
prog
};
hiprtc_program_ptr
p
{
prog
};
if
(
result
!=
HIPRTC_SUCCESS
)
if
(
result
!=
HIPRTC_SUCCESS
)
MIGRAPHX
_HIPRTC_THROW
(
result
,
"Create program failed."
);
RTC
_HIPRTC_THROW
(
result
,
"Create program failed."
);
return
p
;
return
p
;
}
}
...
@@ -252,11 +237,11 @@ struct hiprtc_program
...
@@ -252,11 +237,11 @@ struct hiprtc_program
std
::
string
log
()
const
std
::
string
log
()
const
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
MIGRAPHX
_HIPRTC
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
RTC
_HIPRTC
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
if
(
n
==
0
)
if
(
n
==
0
)
return
{};
return
{};
std
::
string
buffer
(
n
,
'\0'
);
std
::
string
buffer
(
n
,
'\0'
);
MIGRAPHX
_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
RTC
_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
assert
(
buffer
.
back
()
!=
0
);
assert
(
buffer
.
back
()
!=
0
);
return
buffer
;
return
buffer
;
}
}
...
@@ -264,108 +249,28 @@ struct hiprtc_program
...
@@ -264,108 +249,28 @@ struct hiprtc_program
std
::
vector
<
char
>
get_code_obj
()
const
std
::
vector
<
char
>
get_code_obj
()
const
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
MIGRAPHX
_HIPRTC
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
RTC
_HIPRTC
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
std
::
vector
<
char
>
buffer
(
n
);
std
::
vector
<
char
>
buffer
(
n
);
MIGRAPHX
_HIPRTC
(
hiprtcGetCode
(
prog
.
get
(),
buffer
.
data
()));
RTC
_HIPRTC
(
hiprtcGetCode
(
prog
.
get
(),
buffer
.
data
()));
return
buffer
;
return
buffer
;
}
}
};
};
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
src_file
>
srcs
,
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
const
std
::
vector
<
src_file
>&
srcs
,
const
std
::
string
&
params
,
const
compile_options
&
options
)
const
std
::
string
&
arch
)
{
{
hiprtc_program
prog
(
std
::
move
(
srcs
));
hiprtc_program
prog
(
srcs
);
auto
options
=
ck
::
host
::
SplitString
(
params
,
' '
);
auto
flags
=
ck
::
host
::
SplitString
(
options
.
flags
,
' '
);
options
.
push_back
(
"-DMIGRAPHX_USE_HIPRTC=1"
);
prog
.
compile
(
flags
);
if
(
true
)
{
options
.
push_back
(
"-DMIGRAPHX_HAS_DPP=0"
);
options
.
push_back
(
"-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"
);
options
.
push_back
(
"-Wno-reserved-identifier"
);
options
.
push_back
(
"-Wno-unused-parameter"
);
options
.
push_back
(
"-Wno-gnu-line-marker"
);
options
.
push_back
(
"-Wno-old-style-cast"
);
}
if
(
true
)
options
.
push_back
(
"-DMIGRAPHX_DEBUG"
);
if
(
std
::
none_of
(
options
.
begin
(),
options
.
end
(),
[](
const
std
::
string
&
s
)
{
return
ck
::
host
::
StartsWith
(
s
,
"--std="
)
or
ck
::
host
::
StartsWith
(
s
,
"-std="
);
}))
options
.
push_back
(
"-std=c++17"
);
options
.
push_back
(
"-fno-gpu-rdc"
);
options
.
push_back
(
"-O3"
);
options
.
push_back
(
"-Wno-cuda-compat"
);
options
.
push_back
(
"--offload-arch="
+
arch
);
prog
.
compile
(
options
);
return
{
prog
.
get_code_obj
()};
return
{
prog
.
get_code_obj
()};
}
}
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flag
s
)
static
kernel
hiprtc_compile_kernel
(
const
std
::
vector
<
src_file
>&
srcs
,
compile_options
option
s
)
{
{
hiprtc_program
prog
{
" "
};
options
.
flags
+=
" -I. -O3"
;
try
options
.
flags
+=
" -std=c++17"
;
{
options
.
flags
+=
" --offload-arch="
+
get_device_name
();
prog
.
compile
(
flags
,
true
);
auto
cos
=
compile_hip_src_with_hiprtc
(
srcs
,
options
);
return
true
;
}
catch
(...)
{
return
false
;
}
}
bool
hip_accept_non_uniform_wg
()
{
static
bool
non_uniform_wg
=
hip_has_flags
({
"-fno-offload-uniform-block"
});
return
non_uniform_wg
;
}
static
std
::
vector
<
std
::
string
>
get_compiler_warnings
()
{
std
::
vector
<
std
::
string
>
warnings
=
{
"-Weverything"
,
"-Wno-c++98-compat"
,
"-Wno-c++98-compat-pedantic"
,
"-Wno-conversion"
,
"-Wno-double-promotion"
,
"-Wno-exit-time-destructors"
,
"-Wno-extra-semi"
,
"-Wno-extra-semi-stmt"
,
"-Wno-float-conversion"
,
"-Wno-gnu-anonymous-struct"
,
"-Wno-gnu-zero-variadic-macro-arguments"
,
"-Wno-missing-prototypes"
,
"-Wno-nested-anon-types"
,
"-Wno-padded"
,
"-Wno-shorten-64-to-32"
,
"-Wno-sign-conversion"
,
"-Wno-sign-compare"
,
"-Wno-unused-command-line-argument"
,
"-Wno-weak-vtables"
,
"-Wno-c99-extensions"
,
};
if
(
hip_has_flags
({
"-Werror"
,
"-Wunsafe-buffer-usage"
}))
warnings
.
push_back
(
"-Wno-unsafe-buffer-usage"
);
return
warnings
;
}
const
std
::
vector
<
std
::
string
>&
compiler_warnings
()
{
static
std
::
vector
<
std
::
string
>
warnings
=
get_compiler_warnings
();
return
warnings
;
}
static
kernel
hiprtc_compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
)
{
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
srcs
.
push_back
(
src_file
{
std
::
string
(
"main.cpp"
),
content
});
options
.
params
+=
" "
+
ck
::
host
::
JoinStrings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -Werror"
;
auto
cos
=
compile_hip_src_with_hiprtc
(
srcs
,
options
.
params
,
get_device_name
());
if
(
cos
.
size
()
!=
1
)
if
(
cos
.
size
()
!=
1
)
std
::
runtime_error
(
"No code object"
);
std
::
runtime_error
(
"No code object"
);
auto
&
obj
=
cos
.
front
();
auto
&
obj
=
cos
.
front
();
...
@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options
...
@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options
return
kernel
{
obj
.
data
(),
options
.
kernel_name
};
return
kernel
{
obj
.
data
(),
options
.
kernel_name
};
}
}
kernel
compile_kernel
(
const
std
::
string
&
content
,
compile_options
options
)
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
srcs
,
compile_options
options
)
{
{
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)))
if
(
ck
::
EnvIsEnabled
(
CK_ENV
(
CK_CODEGEN_TESTS_ENABLE_HIPRTC
)))
{
{
return
hiprtc_compile_kernel
(
content
,
options
);
return
hiprtc_compile_kernel
(
srcs
,
options
);
}
}
else
else
{
{
options
.
additional_src_files
.
push_back
({
"main.cpp"
,
content
});
return
clang_compile_kernel
(
srcs
,
options
);
return
clang_compile_kernel
(
options
.
additional_src_files
,
options
);
}
}
}
}
kernel
compile_kernel
(
const
std
::
vector
<
src_file
>&
src
,
compile_options
options
)
{
return
clang_compile_kernel
(
src
,
options
);
}
}
// namespace rtc
}
// namespace rtc
include/ck/ck.hpp
View file @
f52c2a4d
...
@@ -4,7 +4,9 @@
...
@@ -4,7 +4,9 @@
#pragma once
#pragma once
#include "ck/config.h"
#include "ck/config.h"
#ifndef __HIPCC_RTC__
#ifndef __HIPCC_RTC__
#include "ck/utility/env.hpp"
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment