Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b7419eec
Commit
b7419eec
authored
Nov 05, 2023
by
Jing Zhang
Browse files
seperate float a/b
parent
57d0ea67
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
41 deletions
+71
-41
example/01_gemm/gemm_dl_fp16.cpp
example/01_gemm/gemm_dl_fp16.cpp
+4
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...de/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
+31
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...de/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
+34
-27
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-1
No files found.
example/01_gemm/gemm_dl_fp16.cpp
View file @
b7419eec
...
@@ -6,12 +6,12 @@
...
@@ -6,12 +6,12 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using
ADataType
=
ck
::
half_t
;
using
ADataType
=
ck
::
half_t
;
using
BDataType
=
ck
::
half
_t
;
using
BDataType
=
int8
_t
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
...
@@ -23,12 +23,11 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali
...
@@ -23,12 +23,11 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDl
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmDl
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer|
C
ThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer|
B
ThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | Order| | |
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmMNPadding
,
64
,
1
,
512
,
2
,
4
,
1
,
8
,
1
,
S
<
1
,
1
,
1
,
4
>
,
S
<
2
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
2
,
8
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
>
;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmMNPadding
,
64
,
1
,
64
,
16
,
4
,
1
,
1
,
1
,
S
<
1
,
1
,
1
,
4
>
,
S
<
16
,
1
,
1
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
4
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
3
,
4
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
View file @
b7419eec
...
@@ -80,6 +80,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -80,6 +80,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
using
GridwiseGemm
=
using
GridwiseGemm
=
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
GridwiseGemmDl_km_kn_mn_v1r3
<
BlockSize
,
ADataType
,
ADataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
...
@@ -184,10 +185,17 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -184,10 +185,17 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float
ave_time
=
0
;
float
ave_time
=
0
;
using
ComputeType
=
ADataType
;
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
CDataType
,
true
,
true
>
;
ADataType
,
BDataType
,
ComputeType
,
CDataType
,
true
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -206,8 +214,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -206,8 +214,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
CDataType
,
true
,
false
>
;
ADataType
,
BDataType
,
ComputeType
,
CDataType
,
true
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -226,8 +239,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -226,8 +239,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
CDataType
,
false
,
true
>
;
ADataType
,
BDataType
,
ComputeType
,
CDataType
,
false
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -246,8 +264,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
...
@@ -246,8 +264,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
}
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
kernel_gemm_dl_v1r3
<
GridwiseGemm
,
ADataType
,
CDataType
,
false
,
false
>
;
ADataType
,
BDataType
,
ComputeType
,
CDataType
,
false
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
View file @
b7419eec
...
@@ -19,7 +19,9 @@
...
@@ -19,7 +19,9 @@
namespace
ck
{
namespace
ck
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
ComputeType
,
typename
FloatC
,
typename
FloatC
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
...
@@ -27,8 +29,8 @@ __global__ void
...
@@ -27,8 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_dl_v1r3
(
const
FloatA
B
*
__restrict__
p_a_grid
,
kernel_gemm_dl_v1r3
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
index_t
M
,
const
index_t
M
,
const
index_t
N
,
const
index_t
N
,
...
@@ -38,9 +40,9 @@ __global__ void
...
@@ -38,9 +40,9 @@ __global__ void
const
index_t
StrideC
)
const
index_t
StrideC
)
{
{
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
ComputeType
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
__shared__
ComputeType
p_shared_block
[
shared_block_size
];
const
auto
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
const
auto
a_grid_desc_k0_m_k1
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
const
auto
b_grid_desc_k0_n_k1
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
...
@@ -68,7 +70,8 @@ __global__ void
...
@@ -68,7 +70,8 @@ __global__ void
}
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
...
@@ -121,7 +124,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -121,7 +124,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr
auto
a_block_aligned_space_size
=
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k_m
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_block_desc_k_m
.
GetElementSpaceSize
(),
max_lds_align
);
return
2
*
(
a_block_aligned_space_size
)
*
sizeof
(
FloatAB
);
return
2
*
(
a_block_aligned_space_size
)
*
sizeof
(
ComputeType
);
}
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
index_t
M
,
index_t
N
)
...
@@ -368,12 +371,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -368,12 +371,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
decltype
(
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
decltype
(
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}));
using
ComputeType
=
FloatA
;
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
static
void
__device__
static
void
Run
(
const
FloatA
B
*
__restrict__
p_a_grid
,
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
ComputeType
*
__restrict__
p_shared_block
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
&
c_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
&
c_grid_desc_m0_m10_m11_n0_n10_n11
,
...
@@ -423,6 +428,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -423,6 +428,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
a_k0_m_k1_block_desc
.
GetElementSpaceSize
()
&&
a_k0_m_k1_block_desc
.
GetElementSpaceSize
()
&&
"wrong!"
);
"wrong!"
);
ignore
=
a_global_buf
;
// A matrix blockwise copy
// A matrix blockwise copy
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v5r1
<
auto
a_blockwise_copy
=
BlockwiseTensorSliceTransfer_v5r1
<
BlockSize
,
BlockSize
,
...
@@ -431,8 +438,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -431,8 +438,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
B
,
FloatA
,
FloatAB
,
ComputeType
,
remove_reference_t
<
decltype
(
a_grid_desc_k0_m0_m1_k1
)
>
,
remove_reference_t
<
decltype
(
a_grid_desc_k0_m0_m1_k1
)
>
,
decltype
(
a_block_desc_k0_m0_m1_k1
),
decltype
(
a_block_desc_k0_m0_m1_k1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -451,8 +458,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -451,8 +458,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple
(
Number
<
K0PerBlock
>
{},
I1
,
Number
<
NPerThread
>
{},
Number
<
K1
>
{}));
make_tuple
(
Number
<
K0PerBlock
>
{},
I1
,
Number
<
NPerThread
>
{},
Number
<
K1
>
{}));
auto
b_threadwise_copy
=
auto
b_threadwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
Float
A
B
,
ThreadwiseTensorSliceTransfer_v2
<
FloatB
,
FloatAB
,
ComputeType
,
remove_reference_t
<
decltype
(
b_grid_desc_k0_n0_n1_k1
)
>
,
remove_reference_t
<
decltype
(
b_grid_desc_k0_n0_n1_k1
)
>
,
decltype
(
b_thread_desc_k0_n0_n1_k1
),
decltype
(
b_thread_desc_k0_n0_n1_k1
),
Sequence
<
K0PerBlock
,
1
,
NPerThread
,
K1
.
value
>
,
Sequence
<
K0PerBlock
,
1
,
NPerThread
,
K1
.
value
>
,
...
@@ -470,8 +477,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -470,8 +477,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const
auto
blockwise_gemm
=
const
auto
blockwise_gemm
=
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
FloatAB
,
ComputeType
,
FloatAB
,
ComputeType
,
FloatAcc
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_thread_desc
),
decltype
(
b_k0_n_k1_thread_desc
),
...
@@ -489,12 +496,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -489,12 +496,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
constexpr
auto
a_block_aligned_space_size
=
math
::
integer_least_multiple
(
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
a_block_desc_k0_m0_m1_k1
.
GetElementSpaceSize
(),
max_lds_align
);
FloatAB
*
p_a_block_double
=
p_shared_block
;
ComputeType
*
p_a_block_double
=
p_shared_block
;
auto
b_thread_odd_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
auto
b_thread_odd_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
>
(
b_k0_n_k1_thread_desc
.
GetElementSpaceSize
());
b_k0_n_k1_thread_desc
.
GetElementSpaceSize
());
auto
b_thread_even_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
auto
b_thread_even_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeType
>
(
b_k0_n_k1_thread_desc
.
GetElementSpaceSize
());
b_k0_n_k1_thread_desc
.
GetElementSpaceSize
());
// register allocation for output
// register allocation for output
...
@@ -516,8 +523,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -516,8 +523,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS
// LDS double buffer: preload data into LDS
{
{
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
//
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_even_buf
);
//
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
,
b_global_buf
,
...
@@ -544,7 +551,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -544,7 +551,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
//
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
,
b_global_buf
,
...
@@ -558,7 +565,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -558,7 +565,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_thread_even_buf
,
c_thread_buf
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_odd_buf
);
//
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
// odd iteration
// odd iteration
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m0_m1_k1
,
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc_k0_m0_m1_k1
,
...
@@ -568,7 +575,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -568,7 +575,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
// LDS doubel buffer: load next data from device mem
// LDS doubel buffer: load next data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
//
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
,
b_global_buf
,
...
@@ -582,7 +589,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -582,7 +589,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm
.
Run
(
a_block_odd_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_odd_buf
,
b_thread_odd_buf
,
c_thread_buf
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_even_buf
);
//
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
k_block_data_begin
+=
2
*
K0PerBlock
;
k_block_data_begin
+=
2
*
K0PerBlock
;
}
while
(
k_block_data_begin
<
K0
-
2
*
K0PerBlock
);
}
while
(
k_block_data_begin
<
K0
-
2
*
K0PerBlock
);
...
@@ -598,7 +605,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -598,7 +605,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
block_sync_lds
();
block_sync_lds
();
// LDS double buffer: load last data from device mem
// LDS double buffer: load last data from device mem
a_blockwise_copy
.
RunRead
(
a_grid_desc_k0_m0_m1_k1
,
a_global_buf
);
//
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_threadwise_copy
.
Run
(
b_grid_desc_k0_n0_n1_k1
,
b_global_buf
,
b_global_buf
,
...
@@ -610,7 +617,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
...
@@ -610,7 +617,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_even_buf
,
b_thread_even_buf
,
c_thread_buf
);
// LDS double buffer: store last data to LDS
// LDS double buffer: store last data to LDS
a_blockwise_copy
.
RunWrite
(
a_block_desc_k0_m0_m1_k1
,
a_block_odd_buf
);
//
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
block_sync_lds
();
block_sync_lds
();
...
...
script/cmake-ck-dev.sh
View file @
b7419eec
...
@@ -12,7 +12,8 @@ cmake
...
@@ -12,7 +12,8 @@ cmake
-save-temps=
$PWD
"
\
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx1100"
\
-D
GPU_TARGETS
=
"gfx90a"
\
-D
DL_KERNEL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment