Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
3b2a7aee
"test/git@developer.sourcefind.cn:change/sglang.git" did not exist on "083629c23564e1a64deaa052f1df5c5d914358d8"
Commit
3b2a7aee
authored
Oct 09, 2024
by
Mirza Halilcevic
Browse files
Address PR comments.
parent
d3a96e51
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
86 additions
and
86 deletions
+86
-86
codegen/include/ck/host/stringutils.hpp
codegen/include/ck/host/stringutils.hpp
+0
-28
codegen/test/CMakeLists.txt
codegen/test/CMakeLists.txt
+0
-7
codegen/test/rtc/CMakeLists.txt
codegen/test/rtc/CMakeLists.txt
+6
-0
codegen/test/rtc/src/compile_kernel.cpp
codegen/test/rtc/src/compile_kernel.cpp
+36
-32
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+44
-19
No files found.
codegen/include/ck/host/stringutils.hpp
View file @
3b2a7aee
...
@@ -100,33 +100,5 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
...
@@ -100,33 +100,5 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
return
result
;
return
result
;
}
}
inline
bool
StartsWith
(
const
std
::
string
&
value
,
const
std
::
string
&
prefix
)
{
if
(
prefix
.
size
()
>
value
.
size
())
return
false
;
else
return
std
::
equal
(
prefix
.
begin
(),
prefix
.
end
(),
value
.
begin
());
}
inline
bool
EndsWith
(
const
std
::
string
&
value
,
const
std
::
string
&
suffix
)
{
if
(
suffix
.
size
()
>
value
.
size
())
return
false
;
else
return
std
::
equal
(
suffix
.
rbegin
(),
suffix
.
rend
(),
value
.
rbegin
());
}
inline
std
::
vector
<
std
::
string
>
SplitString
(
const
std
::
string
&
s
,
char
delim
)
{
std
::
vector
<
std
::
string
>
elems
;
std
::
stringstream
ss
(
s
+
delim
);
std
::
string
item
;
while
(
std
::
getline
(
ss
,
item
,
delim
))
{
elems
.
push_back
(
item
);
}
return
elems
;
}
}
// namespace host
}
// namespace host
}
// namespace ck
}
// namespace ck
codegen/test/CMakeLists.txt
View file @
3b2a7aee
option
(
USE_HIPRTC_FOR_CODEGEN_TESTS
"Whether to enable hipRTC for codegen tests."
ON
)
if
(
USE_HIPRTC_FOR_CODEGEN_TESTS
)
add_compile_definitions
(
HIPRTC_FOR_CODEGEN_TESTS
)
message
(
"CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to
${
USE_HIPRTC_FOR_CODEGEN_TESTS
}
"
)
endif
()
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm
)
add_subdirectory
(
rtc
)
add_subdirectory
(
rtc
)
file
(
GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp
)
file
(
GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp
)
...
...
codegen/test/rtc/CMakeLists.txt
View file @
3b2a7aee
...
@@ -2,3 +2,9 @@ file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
...
@@ -2,3 +2,9 @@ file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
add_library
(
ck_rtc
${
RTC_SOURCES
}
)
target_include_directories
(
ck_rtc PUBLIC include
)
target_include_directories
(
ck_rtc PUBLIC include
)
target_link_libraries
(
ck_rtc PUBLIC hip::host
)
target_link_libraries
(
ck_rtc PUBLIC hip::host
)
option
(
USE_HIPRTC_FOR_CODEGEN_TESTS
"Whether to enable hipRTC for codegen tests."
ON
)
if
(
USE_HIPRTC_FOR_CODEGEN_TESTS
)
target_compile_definitions
(
ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS
)
message
(
"CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to
${
USE_HIPRTC_FOR_CODEGEN_TESTS
}
"
)
endif
()
codegen/test/rtc/src/compile_kernel.cpp
View file @
3b2a7aee
#include <ck/host/stringutils.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <rtc/hip.hpp>
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#include <hip/hiprtc.h>
#include <hip/hiprtc.h>
#include <rtc/manage_ptr.hpp>
#endif
#endif
#include <rtc/tmp_dir.hpp>
#include <rtc/tmp_dir.hpp>
#include <cassert>
#include <cassert>
...
@@ -14,6 +14,26 @@
...
@@ -14,6 +14,26 @@
namespace
rtc
{
namespace
rtc
{
bool
EndsWith
(
const
std
::
string
&
value
,
const
std
::
string
&
suffix
)
{
if
(
suffix
.
size
()
>
value
.
size
())
return
false
;
else
return
std
::
equal
(
suffix
.
rbegin
(),
suffix
.
rend
(),
value
.
rbegin
());
}
std
::
vector
<
std
::
string
>
SplitString
(
const
std
::
string
&
s
,
char
delim
)
{
std
::
vector
<
std
::
string
>
elems
;
std
::
stringstream
ss
(
s
+
delim
);
std
::
string
item
;
while
(
std
::
getline
(
ss
,
item
,
delim
))
{
elems
.
push_back
(
item
);
}
return
elems
;
}
template
<
class
T
>
template
<
class
T
>
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
T
generic_read_file
(
const
std
::
string
&
filename
,
size_t
offset
=
0
,
size_t
nbytes
=
0
)
{
{
...
@@ -108,42 +128,27 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
...
@@ -108,42 +128,27 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
#ifdef HIPRTC_FOR_CODEGEN_TESTS
#ifdef HIPRTC_FOR_CODEGEN_TESTS
struct
hiprtc_src_file
{
hiprtc_src_file
()
=
default
;
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
)
{}
std
::
string
path
;
std
::
string
content
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
path
,
"path"
),
f
(
self
.
content
,
"content"
));
}
};
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
std
::
string
hiprtc_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
)
{
{
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
return
"hiprtc: "
+
(
hiprtcGetErrorString
(
err
)
+
(
": "
+
msg
));
}
}
void
hiprtc_check_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
,
const
std
::
string
&
ctx
)
void
hiprtc_check_error
(
hiprtcResult
err
,
const
std
::
string
&
msg
=
""
)
{
{
if
(
err
!=
HIPRTC_SUCCESS
)
if
(
err
!=
HIPRTC_SUCCESS
)
throw
std
::
runtime_error
(
hiprtc_error
(
err
,
msg
));
throw
std
::
runtime_error
(
hiprtc_error
(
err
,
msg
));
}
}
// NOLINTNEXTLINE
struct
hiprtc_src_file
#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
struct
hiprtc_program_destroy
{
{
void
operator
()(
hiprtcProgram
prog
)
const
{
hiprtcDestroyProgram
(
&
prog
);
}
hiprtc_src_file
()
=
default
;
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
)
{}
std
::
string
path
;
std
::
string
content
;
};
};
using
hiprtc_program_
ptr
=
void
hiprtc_program_
destroy
(
hiprtcProgram
prog
)
{
hiprtcDestroyProgram
(
&
prog
);
}
std
::
unique_ptr
<
std
::
remove_pointer_t
<
hiprtcProgram
>
,
hiprtc_program_destroy
>
;
using
hiprtc_program_ptr
=
RTC_MANAGE_PTR
(
hiprtcProgram
,
hiprtc_program_destroy
)
;
template
<
class
...
Ts
>
template
<
class
...
Ts
>
hiprtc_program_ptr
hiprtc_program_create
(
Ts
...
xs
)
hiprtc_program_ptr
hiprtc_program_create
(
Ts
...
xs
)
...
@@ -151,8 +156,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
...
@@ -151,8 +156,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
hiprtcProgram
prog
=
nullptr
;
hiprtcProgram
prog
=
nullptr
;
auto
result
=
hiprtcCreateProgram
(
&
prog
,
xs
...);
auto
result
=
hiprtcCreateProgram
(
&
prog
,
xs
...);
hiprtc_program_ptr
p
{
prog
};
hiprtc_program_ptr
p
{
prog
};
if
(
result
!=
HIPRTC_SUCCESS
)
hiprtc_check_error
(
result
,
"Create program failed."
);
RTC_HIPRTC_THROW
(
result
,
"Create program failed."
);
return
p
;
return
p
;
}
}
...
@@ -193,7 +197,7 @@ struct hiprtc_program
...
@@ -193,7 +197,7 @@ struct hiprtc_program
{
{
for
(
auto
&&
src
:
srcs
)
for
(
auto
&&
src
:
srcs
)
{
{
if
(
ck
::
host
::
EndsWith
(
src
.
path
,
".cpp"
))
if
(
EndsWith
(
src
.
path
,
".cpp"
))
{
{
cpp_src
=
std
::
move
(
src
.
content
);
cpp_src
=
std
::
move
(
src
.
content
);
cpp_name
=
std
::
move
(
src
.
path
);
cpp_name
=
std
::
move
(
src
.
path
);
...
@@ -239,11 +243,11 @@ struct hiprtc_program
...
@@ -239,11 +243,11 @@ struct hiprtc_program
std
::
string
log
()
const
std
::
string
log
()
const
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
RTC_HIPRTC
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
hiprtc_check_error
(
hiprtcGetProgramLogSize
(
prog
.
get
(),
&
n
));
if
(
n
==
0
)
if
(
n
==
0
)
return
{};
return
{};
std
::
string
buffer
(
n
,
'\0'
);
std
::
string
buffer
(
n
,
'\0'
);
RTC_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
hiprtc_check_error
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
assert
(
buffer
.
back
()
!=
0
);
assert
(
buffer
.
back
()
!=
0
);
return
buffer
;
return
buffer
;
}
}
...
@@ -251,9 +255,9 @@ struct hiprtc_program
...
@@ -251,9 +255,9 @@ struct hiprtc_program
std
::
vector
<
char
>
get_code_obj
()
const
std
::
vector
<
char
>
get_code_obj
()
const
{
{
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
RTC_HIPRTC
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
hiprtc_check_error
(
hiprtcGetCodeSize
(
prog
.
get
(),
&
n
));
std
::
vector
<
char
>
buffer
(
n
);
std
::
vector
<
char
>
buffer
(
n
);
RTC_HIPRTC
(
hiprtcGetCode
(
prog
.
get
(),
buffer
.
data
()));
hiprtc_check_error
(
hiprtcGetCode
(
prog
.
get
(),
buffer
.
data
()));
return
buffer
;
return
buffer
;
}
}
};
};
...
@@ -262,7 +266,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src
...
@@ -262,7 +266,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src
const
compile_options
&
options
)
const
compile_options
&
options
)
{
{
hiprtc_program
prog
(
srcs
);
hiprtc_program
prog
(
srcs
);
auto
flags
=
ck
::
host
::
SplitString
(
options
.
flags
,
' '
);
auto
flags
=
SplitString
(
options
.
flags
,
' '
);
prog
.
compile
(
flags
);
prog
.
compile
(
flags
);
return
{
prog
.
get_code_obj
()};
return
{
prog
.
get_code_obj
()};
}
}
...
...
include/ck/utility/type.hpp
View file @
3b2a7aee
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
namespace
ck
{
namespace
ck
{
#ifdef __HIPCC_RTC__
#ifdef __HIPCC_RTC__
template
<
bool
B
>
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
...
@@ -113,51 +114,75 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
...
@@ -113,51 +114,75 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
return
static_cast
<
T
&&>
(
t_
);
return
static_cast
<
T
&&>
(
t_
);
}
}
template
<
class
T
>
struct
is_const
:
false_type
{};
template
<
class
T
>
template
<
class
T
>
struct
is_const
<
const
T
>
:
true_type
{};
struct
is_const
:
false_type
template
<
class
T
>
{
};
template
<
class
T
>
struct
is_const
<
const
T
>
:
true_type
{
};
template
<
class
T
>
inline
constexpr
bool
is_const_v
=
is_const
<
T
>::
value
;
inline
constexpr
bool
is_const_v
=
is_const
<
T
>::
value
;
template
<
class
T
>
template
<
class
T
>
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
template
<
class
T
>
struct
remove_const
{
typedef
T
type
;
};
template
<
class
T
>
template
<
class
T
>
struct
remove_const
<
const
T
>
{
typedef
T
type
;
};
struct
remove_const
template
<
class
T
>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_const
<
const
T
>
{
typedef
T
type
;
};
template
<
class
T
>
using
remove_const_t
=
typename
remove_const
<
T
>::
type
;
using
remove_const_t
=
typename
remove_const
<
T
>::
type
;
template
<
class
T
>
template
<
class
T
>
inline
constexpr
bool
is_class_v
=
is_class
<
T
>::
value
;
inline
constexpr
bool
is_class_v
=
is_class
<
T
>::
value
;
template
<
class
T
>
template
<
class
T
>
inline
constexpr
bool
is_trivially_copyable_v
=
is_trivially_copyable
<
T
>::
value
;
inline
constexpr
bool
is_trivially_copyable_v
=
is_trivially_copyable
<
T
>::
value
;
template
<
class
...
>
template
<
class
...
>
using
void_t
=
void
;
using
void_t
=
void
;
using
__hip
::
declval
;
template
<
class
T
,
class
U
=
T
&&
>
U
private_declval
(
int
);
template
<
class
T
>
T
private_declval
(
long
);
template
<
class
T
>
auto
declval
()
noexcept
->
decltype
(
private_declval
<
T
>
(
0
));
#else
#else
#include <utility>
#include <utility>
#include <type_traits>
#include <type_traits>
using
std
::
declval
;
using
std
::
false_type
;
using
std
::
forward
;
using
std
::
forward
;
using
std
::
is_base_of
;
using
std
::
is_base_of
;
using
std
::
is_class
;
using
std
::
is_class
;
using
std
::
is_class_v
;
using
std
::
is_const_v
;
using
std
::
is_pointer
;
using
std
::
is_pointer
;
using
std
::
is_reference
;
using
std
::
is_reference
;
using
std
::
is_reference_v
;
using
std
::
is_trivially_copyable
;
using
std
::
is_trivially_copyable
;
using
std
::
is_trivially_copyable_v
;
using
std
::
is_unsigned
;
using
std
::
is_unsigned
;
using
std
::
remove_const_t
;
using
std
::
remove_cv
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
using
std
::
remove_reference
;
using
std
::
is_const_v
;
using
std
::
is_reference_v
;
using
std
::
remove_const_t
;
using
std
::
is_class_v
;
using
std
::
is_trivially_copyable_v
;
using
std
::
void_t
;
using
std
::
false_type
;
using
std
::
true_type
;
using
std
::
true_type
;
using
std
::
declval
;
using
std
::
void_t
;
#endif
#endif
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment