Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8160c31a
Commit
8160c31a
authored
Jan 28, 2022
by
Chao Liu
Browse files
clean up
parent
0e67221f
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
10 deletions
+39
-10
profiler/include/profile_gemm_impl.hpp
profiler/include/profile_gemm_impl.hpp
+19
-2
test/CMakeLists.txt
test/CMakeLists.txt
+3
-5
test/split_k/main.cpp
test/split_k/main.cpp
+17
-3
No files found.
profiler/include/profile_gemm_impl.hpp
View file @
8160c31a
#pragma once
#pragma once
#include "device_gemm_instance.hpp"
//
#include "device_gemm_instance.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
//
#include "device_gemm_splitk_xdl_instance.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
namespace
device_gemm_instance
{
namespace
device_gemm_instance
{
#if 0
template <>
template <>
void add_device_gemm_instance<float,
void add_device_gemm_instance<float,
float,
float,
...
@@ -70,6 +71,22 @@ void add_device_gemm_instance<ck::half_t,
...
@@ -70,6 +71,22 @@ void add_device_gemm_instance<ck::half_t,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);
#else
void
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
#endif
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace device
...
...
test/CMakeLists.txt
View file @
8160c31a
...
@@ -11,15 +11,13 @@ include_directories(BEFORE
...
@@ -11,15 +11,13 @@ include_directories(BEFORE
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
${
PROJECT_SOURCE_DIR
}
/external/rocm/include
)
)
# test_magic_number_division
set
(
MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp
)
set
(
MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp
)
add_executable
(
test_magic_number_division
${
MAGIC_NUMBER_DIVISISON_SOURCE
}
)
add_executable
(
test_magic_number_division
${
MAGIC_NUMBER_DIVISISON_SOURCE
}
)
target_link_libraries
(
test_magic_number_division PRIVATE host_tensor
)
target_link_libraries
(
test_magic_number_division PRIVATE host_tensor
)
# test_split_k
set
(
SPLIT_K_SOURCE split_k/main.cpp
)
set
(
SPLIT_K_SOURCE split_k/main.cpp
)
add_executable
(
test_split_k
${
SPLIT_K_SOURCE
}
)
add_executable
(
test_split_k
${
SPLIT_K_SOURCE
}
)
target_link_libraries
(
test_split_k PRIVATE host_tensor
)
target_link_libraries
(
test_split_k PRIVATE host_tensor
)
target_link_libraries
(
test_split_k PRIVATE device_gemm_instance
)
target_link_libraries
(
test_split_k PRIVATE device_gemm_instance
)
\ No newline at end of file
test/split_k/main.cpp
View file @
8160c31a
...
@@ -8,11 +8,9 @@
...
@@ -8,11 +8,9 @@
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_splitk_xdl_instance.hpp"
#include "device_gemm_xdl_splitk.hpp"
#include "device_gemm_splitk_xdl.hpp"
enum
GemmMatrixLayout
enum
GemmMatrixLayout
{
{
...
@@ -33,6 +31,7 @@ static std::vector<std::vector<bool>>& GetLayoutType()
...
@@ -33,6 +31,7 @@ static std::vector<std::vector<bool>>& GetLayoutType()
return
LayOut
;
return
LayOut
;
}
}
#if 0
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
static void add_device_gemm_instance_mk_kn_mn(std::vector<DeviceGemmNoOpPtr>& gemm_ptrs)
{
{
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
ck::tensor_operation::device::device_gemm_instance::add_device_splitk_gemm_instance<
...
@@ -84,10 +83,23 @@ static auto& GetAddDeviceGemmInstance()
...
@@ -84,10 +83,23 @@ static auto& GetAddDeviceGemmInstance()
add_device_gemm_instance_km_nk_mn};
add_device_gemm_instance_km_nk_mn};
return AddDeviceGemmInstance;
return AddDeviceGemmInstance;
}
}
#else
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
void
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
#endif
static
void
add_device_gemm_instance
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
gemm_ptrs
,
int
layout
)
static
void
add_device_gemm_instance
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
gemm_ptrs
,
int
layout
)
{
{
#if 0
GetAddDeviceGemmInstance()[layout](gemm_ptrs);
GetAddDeviceGemmInstance()[layout](gemm_ptrs);
#else
if
(
layout
==
2
)
{
add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances
(
gemm_ptrs
);
}
#endif
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -150,6 +162,7 @@ int main(int argc, char* argv[])
...
@@ -150,6 +162,7 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
}
};
};
Tensor
<
float
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
LayOut
[
layout
][
0
]));
Tensor
<
float
>
a_m_k
(
f_host_tensor_descriptor
(
M
,
K
,
StrideA
,
LayOut
[
layout
][
0
]));
Tensor
<
float
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
LayOut
[
layout
][
1
]));
Tensor
<
float
>
b_k_n
(
f_host_tensor_descriptor
(
K
,
N
,
StrideB
,
LayOut
[
layout
][
1
]));
Tensor
<
float
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
LayOut
[
layout
][
2
]));
Tensor
<
float
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
LayOut
[
layout
][
2
]));
...
@@ -213,6 +226,7 @@ int main(int argc, char* argv[])
...
@@ -213,6 +226,7 @@ int main(int argc, char* argv[])
success
=
true
;
success
=
true
;
}
}
}
}
if
(
success
)
if
(
success
)
{
{
std
::
cout
<<
"test split k : Pass"
<<
std
::
endl
;
std
::
cout
<<
"test split k : Pass"
<<
std
::
endl
;
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment