Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2faeaece
Commit
2faeaece
authored
Apr 13, 2022
by
j4yan
Browse files
navi_gemm_km_kn_mn_fp32 compiles and passes one test.
parent
fd7eee0d
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
141 additions
and
100 deletions
+141
-100
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp
.../tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
+76
-79
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp
...ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp
+9
-8
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
+1
-3
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+9
-0
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
.../gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
+42
-6
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/gemm_dlops/CMakeLists.txt
test/gemm_dlops/CMakeLists.txt
+1
-1
test/gemm_dlops/gemm_dlops_fp32.cpp
test/gemm_dlops/gemm_dlops_fp32.cpp
+1
-2
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp
View file @
2faeaece
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v
2
.hpp"
#include "threadwise_tensor_slice_transfer_v
4r1
.hpp"
#include "threadwise_contraction_dlops.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/device/device_gemm_dlops.hpp
View file @
2faeaece
...
@@ -33,7 +33,8 @@ template <
...
@@ -33,7 +33,8 @@ template <
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
K0PerBlock
,
index_t
K1
,
index_t
M1PerThread
,
index_t
M1PerThread
,
index_t
N1PerThread
,
index_t
N1PerThread
,
index_t
KPerThread
,
index_t
KPerThread
,
...
@@ -56,17 +57,13 @@ template <
...
@@ -56,17 +57,13 @@ template <
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
typename
AGridStepHacks
,
typename
BGridStepHacks
,
typename
CGridStepHacks
,
typename
AGridMoveSliceWindowStepHacks
,
typename
BGridMoveSliceWindowStepHacks
,
enable_if_t
<
enable_if_t
<
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
&&
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
,
is_same_v
<
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
,
bool
>
=
false
>
bool
>
=
false
>
struct
DeviceGemmDlops
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
struct
DeviceGemmDlops
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
A
K0MK1
GridDesc
,
AGridDesc
_K0_M_K1
,
B
K0NK1
GridDesc
,
BGridDesc
_K0_N_K1
,
C
MN
GridDesc
,
CGridDesc
_M_N
,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
K
0
PerBlock
,
M1PerThread
,
M1PerThread
,
N1PerThread
,
N1PerThread
,
KPerThread
,
KPerThread
,
...
@@ -228,18 +225,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -228,18 +225,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstAccessOrder
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
CThreadTransferDstScalarPerVector
>
;
AGridStepHacks
,
BGridStepHacks
,
using
AGridDesc_K0_M0_M1_K1
=
CGridStepHacks
,
AGridMoveSliceWindowStepHacks
,
BGridMoveSliceWindowStepHacks
>
;
using
AK0M0M1K1GridDesc
=
decltype
(
GridwiseGemm
::
MakeAK0M0M1K1GridDescriptor
(
AGridDesc_K0_M_K1
{}));
decltype
(
GridwiseGemm
::
MakeAK0M0M1K1GridDescriptor
(
AGridDesc_K0_M_K1
{}));
using
BK0N0N1K1GridDesc
=
decltype
(
GridwiseGemm
::
MakeBKN0N1GridDescriptor
(
BGridDesc_K0_N_K1
{}));
using
BGridDesc_K0_N0_N1_K1
=
using
CM0M10M11N0N10N11GridDesc
=
decltype
(
GridwiseGemm
::
MakeBK0N0N1K1GridDescriptor
(
BGridDesc_K0_N_K1
{}));
using
CGridDesc_M0_M10_M11_N0_N10_N11
=
decltype
(
GridwiseGemm
::
MakeCM0M10M11N0N10N11GridDescriptor
(
CGridDesc_M_N
{}));
decltype
(
GridwiseGemm
::
MakeCM0M10M11N0N10N11GridDescriptor
(
CGridDesc_M_N
{}));
using
DefaultBlock2CTileMap
=
decltype
(
GridwiseGemm
::
MakeCBlockIdToM0N0BlockClusterAdaptor
(
CGridDesc_M_N
{}));
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
@@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_k0_m_k1_
{},
a_grid_desc_k0_m0_m1_k1_
{},
b_grid_desc_k0_n_k1_
{},
b_grid_desc_k0_n0_n1_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m0_m10_m11_n0_n10_n11_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
}
N01_
{
N01
}
...
@@ -272,15 +266,19 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -272,15 +266,19 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// b_element_op_{b_element_op},
// b_element_op_{b_element_op},
// c_element_op_{c_element_op}
// c_element_op_{c_element_op}
{
{
a_grid_desc_k0_m_k1_
=
DeviceGemm
Xdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
a_grid_desc_k0_m_k1_
=
DeviceGemm
Dlops
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemm
Xdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
b_grid_desc_k0_n_k1_
=
DeviceGemm
Dlops
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemm
Xdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
c_grid_desc_m_n_
=
DeviceGemm
Dlops
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
))
{
{
c_m0_m10_m11_n0_n10_n11_grid_desc
=
a_grid_desc_k0_m0_m1_k1_
=
GridwiseGemm
::
MakeCM0M10M11N0N10N11GridDescriptor
(
c_m_n_grid_desc
);
GridwiseGemm
::
MakeAK0M0M1K1GridDescriptor
(
a_grid_desc_k0_m_k1_
);
b_grid_desc_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBK0N0N1K1GridDescriptor
(
b_grid_desc_k0_n_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCM0M10M11N0N10N11GridDescriptor
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockIdToM0N0BlockClusterAdaptor
(
c_grid_desc_m_n_
);
GridwiseGemm
::
MakeCBlockIdToM0N0BlockClusterAdaptor
(
c_grid_desc_m_n_
);
...
@@ -292,11 +290,15 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -292,11 +290,15 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
AK0M0M1K1GridDesc
a_k0_m0_m1_k1_grid_desc
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BK0N0N1K1GridDesc
b_k0_n0_n1_k1_grid_desc
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CM0M10M11N0N10N11GridDesc
c_m0_m10_m11_n0_n10_n11_grid_desc
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1_
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1_
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
M01_
;
index_t
N01_
;
index_t
N01_
;
...
@@ -309,36 +311,35 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -309,36 +311,35 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceGemm
Xdl
::
Argument
;
using
Argument
=
DeviceGemm
Dlops
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m0_m1_k1_{"
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m0_m1_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
a_grid_desc_k0_m0_m1_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m0_m1_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.b_grid_desc_k0_n0_n1_k1_{"
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n0_n1_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_k0_n0_n1_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n0_n1_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
),
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m
0_m1
_k1_
.
GetLength
(
I0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
...
@@ -351,10 +352,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -351,10 +352,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
remove_reference_t
<
A
K0M0M1K1
GridDesc
>
,
remove_reference_t
<
AGridDesc
_K0_M0_M1_K1
>
,
remove_reference_t
<
B
K0N0N1K1
GridDesc
>
,
remove_reference_t
<
BGridDesc
_K0_N0_N1_K1
>
,
remove_reference_t
<
CM0M10M11N0N10N11
GridDesc
>
,
remove_reference_t
<
C
GridDesc_
M0
_
M10
_
M11
_
N0
_
N10
_
N11
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
DefaultBlock2CTileMap
>
,
true
,
true
,
true
>
;
true
>
;
...
@@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c
block
id_to_m0_n0_block_cluster_adaptor
_
);
arg
.
block
_2_ctile_map
_
);
}
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
{
...
@@ -377,10 +378,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -377,10 +378,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
remove_reference_t
<
A
K0M0M1K1
GridDesc
>
,
remove_reference_t
<
AGridDesc
_K0_M0_M1_K1
>
,
remove_reference_t
<
B
K0N0N1K1
GridDesc
>
,
remove_reference_t
<
BGridDesc
_K0_N0_N1_K1
>
,
remove_reference_t
<
CM0M10M11N0N10N11
GridDesc
>
,
remove_reference_t
<
C
GridDesc_
M0
_
M10
_
M11
_
N0
_
N10
_
N11
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
DefaultBlock2CTileMap
>
,
true
,
true
,
false
>
;
false
>
;
...
@@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c
block
id_to_m0_n0_block_cluster_adaptor
_
);
arg
.
block
_2_ctile_map
_
);
}
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
{
...
@@ -403,10 +404,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -403,10 +404,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
remove_reference_t
<
A
K0M0M1K1
GridDesc
>
,
remove_reference_t
<
AGridDesc
_K0_M0_M1_K1
>
,
remove_reference_t
<
B
K0N0N1K1
GridDesc
>
,
remove_reference_t
<
BGridDesc
_K0_N0_N1_K1
>
,
remove_reference_t
<
CM0M10M11N0N10N11
GridDesc
>
,
remove_reference_t
<
C
GridDesc_
M0
_
M10
_
M11
_
N0
_
N10
_
N11
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
DefaultBlock2CTileMap
>
,
false
,
false
,
true
>
;
true
>
;
...
@@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c
block
id_to_m0_n0_block_cluster_adaptor
_
);
arg
.
block
_2_ctile_map
_
);
}
}
else
else
{
{
...
@@ -429,10 +430,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -429,10 +430,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
kernel_gemm_dlops_v1r3
<
GridwiseGemm
,
ADataType
,
ADataType
,
CDataType
,
CDataType
,
remove_reference_t
<
A
K0M0M1K1
GridDesc
>
,
remove_reference_t
<
AGridDesc
_K0_M0_M1_K1
>
,
remove_reference_t
<
B
K0N0N1K1
GridDesc
>
,
remove_reference_t
<
BGridDesc
_K0_N0_N1_K1
>
,
remove_reference_t
<
CM0M10M11N0N10N11
GridDesc
>
,
remove_reference_t
<
C
GridDesc_
M0
_
M10
_
M11
_
N0
_
N10
_
N11
>
,
remove_reference_t
<
CBlockIdToM0N0BlockClusterAdaptor
>
,
remove_reference_t
<
DefaultBlock2CTileMap
>
,
false
,
false
,
false
>
;
false
>
;
...
@@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
c
block
id_to_m0_n0_block_cluster_adaptor
_
);
arg
.
block
_2_ctile_map
_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -468,11 +469,8 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -468,11 +469,8 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
}
// polymorphic
// polymorphic
...
@@ -555,17 +553,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
...
@@ -555,17 +553,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceGemm
Xdl
"
str
<<
"DeviceGemm
Dlops
"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
M1PerThread
<<
", "
<<
NPerXDL
<<
", "
<<
N1PerThread
<<
", "
<<
MXdlPerWave
<<
", "
<<
KPerThread
<<
NXdlPerWave
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp
View file @
2faeaece
...
@@ -7,8 +7,9 @@
...
@@ -7,8 +7,9 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer
_v2
.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_k0_m0_m1_k1_grid_desc
),
remove_reference_t
<
decltype
(
a_k0_m0_m1_k1_grid_desc
)
>
,
decltype
(
a_k0_m0_m1_k1_block_desc
),
decltype
(
a_k0_m0_m1_k1_block_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
...
@@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
b_k0_n0_n1_k1_grid_desc
),
remove_reference_t
<
decltype
(
b_k0_n0_n1_k1_grid_desc
)
>
,
decltype
(
b_k0_n0_n1_k1_block_desc
),
decltype
(
b_k0_n0_n1_k1_block_desc
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
,
3
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
...
@@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_k0_m0_m1_k1_grid_desc
,
a_block_slice_copy_step
);
a_k0_m0_m1_k1_grid_desc
,
a_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_block_slice_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_k0_n0_n1_k1_grid_desc
,
b_block_slice_copy_step
);
__syncthreads
();
__syncthreads
();
...
@@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatC
,
FloatC
,
decltype
(
c_m0_m10_m11_n0_n10_n11_thread_desc
),
decltype
(
c_m0_m10_m11_n0_n10_n11_thread_desc
),
decltype
(
c_m0_m10_m11_n0_n10_n11_grid_desc
),
decltype
(
c_m0_m10_m11_n0_n10_n11_grid_desc
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
Sequence
<
1
,
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
c_m10_m11_n10_n11_thread_tensor_lengths
[
I0
],
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
c_m10_m11_n10_n11_thread_tensor_lengths
[
I1
],
...
@@ -569,7 +569,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
...
@@ -569,7 +569,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I1
],
in0
,
in0
,
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I2
],
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
])}
c_m10_m11_n10_n11_thread_origin_idx_on_block
[
I3
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}}
.
Run
(
c_m0_m10_m11_n0_n10_n11_thread_desc
,
.
Run
(
c_m0_m10_m11_n0_n10_n11_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_thread_buf
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
View file @
2faeaece
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#pragma once
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
...
@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
...
@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
};
};
}
// namespace ck
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
2faeaece
...
@@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
...
@@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install
(
TARGETS device_gemm_instance LIBRARY DESTINATION lib
)
install
(
TARGETS device_gemm_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_instance
)
clang_tidy_check
(
device_gemm_instance
)
add_library
(
device_gemm_dlops_instance SHARED device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
)
target_compile_features
(
device_gemm_dlops_instance PUBLIC
)
set_target_properties
(
device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_gemm_dlops_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_gemm_dlops_instance
)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp
View file @
2faeaece
...
@@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
=
using
device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order|
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>
//
DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
// clang-format on
>
;
DeviceGemmDlops
<
F32
,
F32
,
F32
,
F32
,
Col
,
Row
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmDefault
,
256
,
128
,
128
,
8
,
2
,
4
,
4
,
1
,
S
<
8
,
2
>
,
S
<
8
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
2
,
1
,
128
,
1
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
4
,
1
,
1
,
2
>
,
S
<
1
,
2
,
0
,
3
>
,
S
<
1
,
1
,
1
,
2
>
,
S
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
4
>>
;
void
add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
(
void
add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
...
...
test/CMakeLists.txt
View file @
2faeaece
...
@@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve)
...
@@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory
(
conv_util
)
add_subdirectory
(
conv_util
)
add_subdirectory
(
reference_conv_fwd
)
add_subdirectory
(
reference_conv_fwd
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm_dlops
)
add_subdirectory
(
gemm_split_k
)
add_subdirectory
(
gemm_split_k
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
gemm_reduce
)
add_subdirectory
(
batched_gemm
)
add_subdirectory
(
batched_gemm
)
...
...
test/gemm_dlops/CMakeLists.txt
View file @
2faeaece
add_test_executable
(
test_gemm_dlops_fp32 gemm_fp32.cpp
)
add_test_executable
(
test_gemm_dlops_fp32 gemm_
dlops_
fp32.cpp
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE host_tensor
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance
)
target_link_libraries
(
test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance
)
...
...
test/gemm_dlops/gemm_dlops_fp32.cpp
View file @
2faeaece
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <tuple>
#include <tuple>
#include <vector>
#include <vector>
#include "gemm_util.hpp"
#include "
../gemm/
gemm_util.hpp"
#include "config.hpp"
#include "config.hpp"
#include "print.hpp"
#include "print.hpp"
#include "device.hpp"
#include "device.hpp"
...
@@ -15,7 +15,6 @@
...
@@ -15,7 +15,6 @@
#include "host_gemm.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_xdl.hpp"
#include "device_gemm_dlops_c_shuffle.hpp"
#include "element_wise_operation.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "gemm_specialization.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment