gaoqiong / composable_kernel · Commit fd7eee0d
Authored Apr 13, 2022 by j4yan
Parent: ac0d8066

start adding navi21 GEMM

Showing 7 changed files with 744 additions and 26 deletions (+744 -26).
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp  (+11 -0)
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp  (+579 -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp  (+16 -26)
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt  (+3 -0)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp  (+46 -0)
test/gemm_dlops/CMakeLists.txt  (+15 -0)
test/gemm_dlops/gemm_dlops_fp32.cpp  (+74 -0)
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp

@@ -86,6 +86,17 @@ struct BlockwiseTensorSliceTransfer_v5r1
         }
     }

+    template <typename SrcBuffer>
+    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
+    {
+        if(BlockSize == thread_cluster_desc_.GetElementSize() or
+           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
+        {
+            threadwise_transfer_.RunRead(src_desc, src_buf);
+        }
+    }
+
     template <typename DstBuffer>
     __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
     {
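The guard added to RunRead lets a blockwise copy run even when its thread cluster covers fewer threads than the whole block: either every thread participates, or only the threads whose 1-D id falls inside the cluster do. A minimal host-side sketch of the same predicate; the block size of 256 and cluster size of 128 are illustrative values, not taken from this commit:

    #include <cstdio>

    // Host-side sketch of the participation guard added in RunRead above.
    // kBlockSize and kClusterElementSize are hypothetical: they model a
    // 256-thread block whose copy cluster occupies only the first 128 threads.
    constexpr int kBlockSize          = 256;
    constexpr int kClusterElementSize = 128;

    bool participates_in_copy(int thread_id)
    {
        // Mirrors: BlockSize == thread_cluster_desc_.GetElementSize()
        //          or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()
        return kBlockSize == kClusterElementSize || thread_id < kClusterElementSize;
    }

    int main()
    {
        std::printf("thread   0 copies: %d\n", participates_in_copy(0));   // 1: inside cluster
        std::printf("thread 200 copies: %d\n", participates_in_copy(200)); // 0: idles
        return 0;
    }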
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
new file (0 → 100644); this 579-line diff is collapsed in the web view and not reproduced here.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp

@@ -83,12 +83,7 @@ template <index_t BlockSize,
           typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          typename AGridStepHacks,
-          typename BGridStepHacks,
-          typename CGridStepHacks,
-          typename AGridMoveSliceWindowStepHacks,
-          typename BGridMoveSliceWindowStepHacks>
+          index_t CThreadTransferDstScalarPerVector>
 struct GridwiseGemmDlops_km_kn_mn_v1r3
 {
     static constexpr auto I0 = Number<0>{};

@@ -437,8 +432,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         // LDS double buffer: preload data into LDS
         {
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);

             a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
             b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);

@@ -456,17 +451,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         {
             // even iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowStepHacks{});
+                                                a_block_slice_copy_step);
             b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowStepHacks{});
+                                                b_block_slice_copy_step);

             __syncthreads();

             // LDS double buffer: load next data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,

@@ -480,17 +473,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             // odd iteration
             a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowStepHacks{});
+                                                a_block_slice_copy_step);
             b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowStepHacks{});
+                                                b_block_slice_copy_step);

             __syncthreads();

             // LDS double buffer: load next data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);

             // LDS double buffer: GEMM on current data
             blockwise_gemm.Run(

@@ -508,15 +499,15 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
         if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
         {
-            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
-                                                a_block_slice_copy_step,
-                                                AGridMoveSliceWindowStepHacks{});
-            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
-                                                b_block_slice_copy_step,
-                                                BGridMoveSliceWindowStepHacks{});
+            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
+                                                a_block_slice_copy_step);
+            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
+                                                b_block_slice_copy_step);

             __syncthreads();

             // LDS double buffer: load last data from device mem
-            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{});
-            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{});
+            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf);
+            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf);

             // LDS double buffer: GEMM on 2nd-last data
             blockwise_gemm.Run(

@@ -583,8 +574,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
             make_tuple(I0, I0, I0, I0, I0, I0),
             c_thread_buf,
             c_m0_m10_m11_n0_n10_n11_grid_desc,
-            c_grid_buf,
-            CGridStepHacks{});
+            c_grid_buf);
     }
 }
};
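The hunks above drop the compile-time step-hack template arguments from the copy calls without changing the LDS double-buffer schedule itself: preload one buffer, then alternate buffers so that the global loads for the next K-slice overlap with the GEMM on the current one. A host-side sketch of that schedule, with the buffer labels and loop count invented purely for illustration:

    #include <cstdio>

    // Host-side model (not the kernel) of the double-buffer pipeline in
    // GridwiseGemmDlops_km_kn_mn_v1r3: preload the even buffer, then in each
    // step advance the source slice window and load the next K-slice into one
    // buffer while the blockwise GEMM consumes the other.
    int main()
    {
        const int num_k_block_loops = 6; // illustrative; depends on K / K0PerBlock
        int window = 0;                  // models MoveSrcSliceWindow along K

        std::printf("preload: window %d -> even buffer\n", window);

        for(int i = 0; i + 2 < num_k_block_loops; i += 2)
        {
            std::printf("load window %d -> odd  buffer | GEMM on even buffer\n", ++window);
            std::printf("load window %d -> even buffer | GEMM on odd  buffer\n", ++window);
        }

        // Tail: under HasDoubleTailKBlockLoop one more load/GEMM pair runs
        // before the final GEMM, matching the last hunk above.
        std::printf("load window %d -> odd buffer | GEMM on even buffer\n", ++window);
        std::printf("GEMM on odd buffer, then write C to global memory\n");
        return 0;
    }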
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt

@@ -33,6 +33,9 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
     device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp;
     device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp;
     device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp;
+    device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp;
 )

 add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE})
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
new file (0 → 100644)

#include <stdlib.h>
#include "config.hpp"
#include "device_gemm_dlops.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple<
    // clang-format off
    // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
    // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|
    // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order|
    DeviceGemmDlops<F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
    // clang-format on
    >;

void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_gemm_dlops_f32_f32_f32_km_kn_mn_instances{});
}

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
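For context, these instance files follow a common factory pattern: the translation unit owns a tuple of fully specialized device ops and exposes a single registration function. A hedged sketch of a caller (the test added below does essentially this); the forward declaration mirrors the test file, and GetTypeString() is assumed from the DeviceGemm base interface rather than shown in this commit:

    #include <iostream>
    #include <vector>

    #include "device_gemm_dlops.hpp"      // header choice mirrors the instance file above
    #include "element_wise_operation.hpp"

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using DeviceGemmNoOpPtr =
        ck::tensor_operation::device::DeviceGemmPtr<PassThrough, PassThrough, PassThrough>;

    // Forward declaration of the factory, as the test below also does.
    namespace ck {
    namespace tensor_operation {
    namespace device {
    namespace device_gemm_instance {
    void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
    } // namespace device_gemm_instance
    } // namespace device
    } // namespace tensor_operation
    } // namespace ck

    int main()
    {
        std::vector<DeviceGemmNoOpPtr> instances;
        ck::tensor_operation::device::device_gemm_instance::
            add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(instances);

        // Each pointer is a type-erased, fully tuned DeviceGemmDlops specialization.
        for(const auto& p : instances)
            std::cout << p->GetTypeString() << std::endl;
        return 0;
    }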
test/gemm_dlops/CMakeLists.txt
new file (0 → 100644)

add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance)

# add_test_executable(test_gemm_dlops_fp16 gemm_fp16.cpp)
# target_link_libraries(test_gemm_dlops_fp16 PRIVATE host_tensor)
# target_link_libraries(test_gemm_dlops_fp16 PRIVATE device_gemm_dlops_instance)
#
# add_test_executable(test_gemm_dlops_bf16 gemm_bf16.cpp)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE host_tensor)
# target_link_libraries(test_gemm_dlops_bf16 PRIVATE device_gemm_dlops_instance)
#
# add_test_executable(test_gemm_dlops_int8 gemm_int8.cpp)
# target_link_libraries(test_gemm_dlops_int8 PRIVATE host_tensor)
# target_link_libraries(test_gemm_dlops_int8 PRIVATE device_gemm_dlops_instance)
test/gemm_dlops/gemm_dlops_fp32.cpp
new file (0 → 100644)

#include <algorithm>
#include <cstdlib>
#include <half.hpp>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

#include "gemm_util.hpp"
#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_dlops_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceGemmNoOpPtr =
    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                ck::tensor_operation::element_wise::PassThrough>;

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {

void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_dlops_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_dlops_f32_f32_f32_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_dlops_f32_f32_f32_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);

} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

int main()
{
    using ADataType = float;
    using BDataType = float;
    using CDataType = float;

    using RowMajor    = ck::tensor_layout::gemm::RowMajor;
    using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor;

    bool res = true;

    std::vector<DeviceGemmNoOpPtr> gemmPtrs;
    ck::tensor_operation::device::device_gemm_instance::
        add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(gemmPtrs);

    for(auto& gemmPtr : gemmPtrs)
    {
        res &= ck::gemm_util::TestGemm<DeviceGemmNoOpPtr,
                                       ADataType,
                                       BDataType,
                                       CDataType,
                                       ColumnMajor,
                                       RowMajor,
                                       RowMajor,
                                       PassThrough,
                                       PassThrough,
                                       PassThrough>{}(gemmPtr);
    }

    std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;
    return res ? 0 : 1;
}