Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1b4ae8b5
Commit
1b4ae8b5
authored
Dec 16, 2021
by
ltqin
Browse files
add test
parent
982e59b3
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
335 additions
and
111 deletions
+335
-111
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
+1
-1
device_operation/include/device_gemm_splitk_xdl.hpp
device_operation/include/device_gemm_splitk_xdl.hpp
+38
-32
device_operation/include/device_gemm_xdl_instance.hpp
device_operation/include/device_gemm_xdl_instance.hpp
+80
-0
profiler/include/profile_gemm.hpp
profiler/include/profile_gemm.hpp
+1
-78
test/CMakeLists.txt
test/CMakeLists.txt
+7
-0
test/split_k/main.cpp
test/split_k/main.cpp
+208
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
View file @
1b4ae8b5
device_operation/include/device_gemm_splitk_xdl.hpp
View file @
1b4ae8b5
...
...
@@ -387,25 +387,26 @@ struct DeviceGemmSplitKXdl
{
using
Argument
=
DeviceGemmSplitKXdl
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
const
auto
kbatch
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
);
void
ShowInfo
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
const
auto
kbatch
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
...
...
@@ -426,6 +427,9 @@ struct DeviceGemmSplitKXdl
float
ave_time
=
0
;
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
nrepeat
>
0
)
{
ShowInfo
(
arg
);
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
...
...
@@ -441,7 +445,9 @@ struct DeviceGemmSplitKXdl
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
if
(
kbatch
>
1
)
}
if
(
kbatch
>
1
||
nrepeat
<=
0
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
...
...
device_operation/include/device_gemm_xdl_instance.hpp
0 → 100644
View file @
1b4ae8b5
#ifndef DEVICE_GEMM_XDL_INSTANCE
#define DEVICE_GEMM_XDL_INSTANCE

// Declarations of the explicit specializations of add_device_gemm_instance for
// float and ck::half_t, covering the four RowMajor/ColumnMajor combinations of the
// first two layout arguments (the third layout argument is always RowMajor).
//
// NOTE(review): this header declares neither DeviceGemmPtr nor the primary
// add_device_gemm_instance template; both have to be visible before this header is
// included -- confirm which header provides them before including it from new code.

#include <vector>

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

// Device GEMM pointer type whose three element-wise operations are all PassThrough.
using DeviceGemmNoOpPtr = DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough,
                                        ck::tensor_operation::element_wise::PassThrough>;

// Template arguments below are three data types followed by three layouts
// (presumably for A, B and C in that order -- TODO confirm against the primary template).

// ---- float ----

template <>
void add_device_gemm_instance<float,
                              float,
                              float,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<float,
                              float,
                              float,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<float,
                              float,
                              float,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<float,
                              float,
                              float,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

// ---- ck::half_t ----

template <>
void add_device_gemm_instance<ck::half_t,
                              ck::half_t,
                              ck::half_t,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<ck::half_t,
                              ck::half_t,
                              ck::half_t,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<ck::half_t,
                              ck::half_t,
                              ck::half_t,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

template <>
void add_device_gemm_instance<ck::half_t,
                              ck::half_t,
                              ck::half_t,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::ColumnMajor,
                              ck::tensor_layout::gemm::RowMajor>(std::vector<DeviceGemmNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
\ No newline at end of file
profiler/include/profile_gemm.hpp
View file @
1b4ae8b5
#pragma once
#include "device_gemm_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_gemm_instance
{
using
DeviceGemmNoOpPtr
=
DeviceGemmPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
float
,
float
,
float
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
template
<
>
void
add_device_gemm_instance
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
ColumnMajor
,
ck
::
tensor_layout
::
gemm
::
RowMajor
>
(
std
::
vector
<
DeviceGemmNoOpPtr
>&
);
}
// namespace device_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#include "device_gemm_xdl_instance.hpp"
namespace
ck
{
namespace
profiler
{
...
...
test/CMakeLists.txt
View file @
1b4ae8b5
...
...
@@ -16,3 +16,10 @@ set(MAGIC_NUMBER_DIVISISON_SOURCE magic_number_division/main.cpp)
add_executable
(
test_magic_number_division
${
MAGIC_NUMBER_DIVISISON_SOURCE
}
)
target_link_libraries
(
test_magic_number_division PRIVATE host_tensor
)
# Split-K GEMM test executable, built from split_k/main.cpp.
set(SPLIT_K_SOURCE split_k/main.cpp)
add_executable(test_split_k ${SPLIT_K_SOURCE})
target_link_libraries(test_split_k PRIVATE host_tensor device_gemm_instance)
\ No newline at end of file
test/split_k/main.cpp
0 → 100644
View file @
1b4ae8b5
#include <stdlib.h>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "device_gemm_instance.hpp"
#include "host_gemm.hpp"
#include "tensor_layout.hpp"
#include "device_gemm_xdl_instance.hpp"
#include "device_gemm_splitk_xdl.hpp"
// Operand layout combinations selectable on the command line.
// Each enumerator is named <A dims>_<B dims>_<C dims>; its numeric value is the
// layout index used throughout this file.
enum GemmMatrixLayout
{
    MK_KN_MN = 0,
    MK_NK_MN = 1,
    KM_KN_MN = 2,
    KM_NK_MN = 3,
};
// Device GEMM pointer type whose three element-wise operations are all PassThrough.
using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

// Collection of device GEMM instances; filled by the add_device_gemm_instance_* helpers.
using GEMM_PTR = std::vector<DeviceGemmNoOpPtr>;
// Per-layout flags, indexed as LayOut[layout][operand] with operand 0, 1, 2.
// main() passes each flag as the isRevert argument when it builds the host
// descriptor of the A, B and C tensor respectively.
static std::vector<std::vector<bool>> LayOut = {
    {false, false, false}, // layout 0
    {false, true, false},  // layout 1
    {true, false, false},  // layout 2
    {true, true, false},   // layout 3
};
// Forwards gemm_ptrs to the float add_device_gemm_instance specialization whose
// layout arguments are RowMajor, RowMajor, RowMajor.
static void add_device_gemm_instance_mk_kn_mn(GEMM_PTR& gemm_ptrs)
{
    namespace instance = ck::tensor_operation::device::device_gemm_instance;
    using Row          = ck::tensor_layout::gemm::RowMajor;

    instance::add_device_gemm_instance<float, float, float, Row, Row, Row>(gemm_ptrs);
}
// Forwards gemm_ptrs to the float add_device_gemm_instance specialization whose
// layout arguments are RowMajor, ColumnMajor, RowMajor.
static void add_device_gemm_instance_mk_nk_mn(GEMM_PTR& gemm_ptrs)
{
    namespace instance = ck::tensor_operation::device::device_gemm_instance;
    using Row          = ck::tensor_layout::gemm::RowMajor;
    using Col          = ck::tensor_layout::gemm::ColumnMajor;

    instance::add_device_gemm_instance<float, float, float, Row, Col, Row>(gemm_ptrs);
}
// Forwards gemm_ptrs to the float add_device_gemm_instance specialization whose
// layout arguments are ColumnMajor, RowMajor, RowMajor.
static void add_device_gemm_instance_km_kn_mn(GEMM_PTR& gemm_ptrs)
{
    namespace instance = ck::tensor_operation::device::device_gemm_instance;
    using Row          = ck::tensor_layout::gemm::RowMajor;
    using Col          = ck::tensor_layout::gemm::ColumnMajor;

    instance::add_device_gemm_instance<float, float, float, Col, Row, Row>(gemm_ptrs);
}
// Forwards gemm_ptrs to the float add_device_gemm_instance specialization whose
// layout arguments are ColumnMajor, ColumnMajor, RowMajor.
static void add_device_gemm_instance_km_nk_mn(GEMM_PTR& gemm_ptrs)
{
    namespace instance = ck::tensor_operation::device::device_gemm_instance;
    using Row          = ck::tensor_layout::gemm::RowMajor;
    using Col          = ck::tensor_layout::gemm::ColumnMajor;

    instance::add_device_gemm_instance<float, float, float, Col, Col, Row>(gemm_ptrs);
}
// Dispatch table from layout index to the helper that registers the matching
// device GEMM instances. The entry order must follow the layout numbering
// (0: mk_kn_mn, 1: mk_nk_mn, 2: km_kn_mn, 3: km_nk_mn).
static std::vector<decltype(&add_device_gemm_instance_mk_kn_mn)> AddDeviceGemmInstance = {
    add_device_gemm_instance_mk_kn_mn, // 0
    add_device_gemm_instance_mk_nk_mn, // 1
    add_device_gemm_instance_km_kn_mn, // 2
    add_device_gemm_instance_km_nk_mn  // 3
};
// Calls the AddDeviceGemmInstance entry selected by `layout`, passing gemm_ptrs to it.
//
// gemm_ptrs: container handed to the selected helper.
// layout:    index into AddDeviceGemmInstance.
//
// Throws std::out_of_range when `layout` is negative or not a valid table index,
// instead of indexing the table out of bounds.
static void add_device_gemm_instance(GEMM_PTR& gemm_ptrs, int layout)
{
    // A negative value converts to a very large index, which at() rejects as well.
    const auto index = static_cast<std::size_t>(layout);
    AddDeviceGemmInstance.at(index)(gemm_ptrs);
}
// Compares two tensors element by element.
//
// ref:    reference values.
// result: values under test.
//
// Returns true only when both tensors hold the same number of elements and every
// pair of elements differs by at most 1e-6 in absolute value. A NaN difference
// (for example a NaN in `result`) counts as a mismatch.
template <typename T>
static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
{
    constexpr double max_diff = 1e-6;

    // Indexing `result` with the indices of a longer `ref` would read out of bounds.
    if(ref.mData.size() != result.mData.size())
    {
        return false;
    }

    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        const double diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));

        // Written as a negated <= so that a NaN diff fails the check; every ordered
        // comparison against NaN is false, so `max_diff < diff` would let it through.
        if(!(diff <= max_diff))
        {
            return false;
        }
    }

    return true;
}
// Split-K GEMM test driver.
//
// Command line: <layout> <M> <N> <K> <StrideA> <StrideB> <StrideC>
//   layout is 0..3 and selects the row of LayOut / the entry of AddDeviceGemmInstance.
//
// Computes a host reference with host_gemm_mk_kn_mn, runs every registered device
// GEMM instance that reports the argument as supported, and compares each device
// result with the reference via check_out.
//
// Returns 0 when at least one instance ran and all instances that ran matched the
// reference; returns 1 on a usage error, on a mismatch, or when no instance ran.
int main(int argc, char* argv[])
{
    if(argc != 8)
    {
        printf("arg1: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
        printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
        // For layouts 2 and 3 the A operand is stored k-by-m (KM_* in GemmMatrixLayout).
        printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
        printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
        printf("arg2 to 7: M, N, K, StrideA, StrideB, StrideC\n");
        return 1;
    }

    // Kept as a plain int until validated: converting an out-of-range value to
    // GemmMatrixLayout would be undefined behaviour.
    const int layout  = std::stoi(argv[1]);
    const int M       = std::stoi(argv[2]);
    const int N       = std::stoi(argv[3]);
    const int K       = std::stoi(argv[4]);
    const int StrideA = std::stoi(argv[5]);
    const int StrideB = std::stoi(argv[6]);
    const int StrideC = std::stoi(argv[7]);

    if(layout > 3 || layout < 0)
    {
        printf("arg1 must be 0 ,1 ,2 or 3 \n");
        return 1;
    }

    // Builds a row x col host descriptor. isRevert selects strides {1, stride};
    // otherwise the strides are {stride, 1}.
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, bool isRevert) {
            if(isRevert)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
        };

    Tensor<float> a_m_k(f_host_tensor_descriptor(M, K, StrideA, LayOut[layout][0]));
    Tensor<float> b_k_n(f_host_tensor_descriptor(K, N, StrideB, LayOut[layout][1]));
    Tensor<float> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));
    Tensor<float> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, LayOut[layout][2]));

    // init data
    // hardware_concurrency() may return 0 when the value cannot be determined;
    // fall back to a single thread in that case.
    const unsigned hw_threads    = std::thread::hardware_concurrency();
    const std::size_t num_thread = hw_threads != 0 ? hw_threads : 1;

    a_m_k.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    b_k_n.GenerateTensorValue(GeneratorTensor_2<float>{-5, 5}, num_thread);
    // set zero to c_device_buf
    c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<float>{}, num_thread);

    // host reference result
    host_gemm_mk_kn_mn(a_m_k,
                       b_k_n,
                       c_m_n_host_result,
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{},
                       ck::tensor_operation::element_wise::PassThrough{});

    DeviceMem a_device_buf(sizeof(float) * a_m_k.mDesc.GetElementSpace());
    DeviceMem b_device_buf(sizeof(float) * b_k_n.mDesc.GetElementSpace());
    DeviceMem c_device_buf(sizeof(float) * c_m_n_device_result.mDesc.GetElementSpace());

    a_device_buf.ToDevice(a_m_k.mData.data());
    b_device_buf.ToDevice(b_k_n.mData.data());
    c_device_buf.ToDevice(c_m_n_device_result.mData.data());

    // add device GEMM instances
    GEMM_PTR gemm_ptrs;
    add_device_gemm_instance(gemm_ptrs, layout);

    // Stays false when no instance supports the argument, so that case reports failure.
    bool success = false;
    for(auto& gemm_ptr : gemm_ptrs)
    {
        auto argument_ptr =
            gemm_ptr->MakeArgumentPointer(static_cast<float*>(a_device_buf.GetDeviceBuffer()),
                                          static_cast<float*>(b_device_buf.GetDeviceBuffer()),
                                          static_cast<float*>(c_device_buf.GetDeviceBuffer()),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{},
                                          ck::tensor_operation::element_wise::PassThrough{});

        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            // NOTE(review): the second argument (0) is presumably the repeat count,
            // i.e. a single untimed launch -- confirm against the invoker.
            invoker_ptr->Run(argument_ptr.get(), 0);

            c_device_buf.FromDevice(c_m_n_device_result.mData.data());
            if(!check_out(c_m_n_host_result, c_m_n_device_result))
            {
                success = false;
                break;
            }
            success = true;
        }
    }

    if(success)
    {
        std::cout << "test split k : Pass" << std::endl;
    }
    else
    {
        std::cout << "test split k: Fail " << std::endl;
    }

    // Report failure through the exit status so that test runners can detect it.
    return success ? 0 : 1;
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment