Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
114f9298
Commit
114f9298
authored
Nov 24, 2021
by
ltqin
Browse files
using atomic
parent
b7ec2078
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
84 additions
and
42 deletions
+84
-42
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
...eration/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
+3
-3
device_operation/include/device_gemm_splitk_xdl.hpp
device_operation/include/device_gemm_splitk_xdl.hpp
+81
-39
No files found.
device_operation/device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn.cpp
View file @
114f9298
...
@@ -30,9 +30,9 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
...
@@ -30,9 +30,9 @@ using device_gemm_xdl_instance_f32_f32_f32_mk_kn_mn = std::tuple<
//#################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#################| Type| Type| Type| Type| | | | Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ThreadSlice| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| SrcDstVectorDim| DstScalar| AddExtraM| AddExtraN|
//#################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#################| | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_N_K1| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| Lengths_K0_N_K1| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
false
,
1
>
,
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
360
>
,
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
true
,
36
0
>
,
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
48
0
>
,
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
true
,
48
0
>
DeviceGemmSplitKXdl
<
F32
,
F32
,
F32
,
F32
,
Row
,
Row
,
Row
,
256
,
128
,
128
,
4
,
4
,
32
,
32
,
2
,
2
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
4
,
4
,
S
<
1
,
1
,
2
,
4
>
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
4
,
7
,
1
,
true
,
true
,
72
0
>
// clang-format on
// clang-format on
>
;
>
;
#else
#else
...
...
device_operation/include/device_gemm_splitk_xdl.hpp
View file @
114f9298
...
@@ -11,6 +11,10 @@
...
@@ -11,6 +11,10 @@
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"
#ifndef CK_RUN_KERNEL_AND_TIME
#define CK_RUN_KERNEL_AND_TIME 0
#endif
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
...
@@ -49,7 +53,6 @@ template <typename ADataType,
...
@@ -49,7 +53,6 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
ck
::
index_t
CThreadTransferDstScalarPerVector
,
bool
ABlockLdsAddExtraM
,
bool
ABlockLdsAddExtraM
,
bool
BBlockLdsAddExtraN
,
bool
BBlockLdsAddExtraN
,
bool
IsSplitK
,
ck
::
index_t
DesiredGridSize
>
ck
::
index_t
DesiredGridSize
>
struct
DeviceGemmSplitKXdl
:
public
DeviceGemm
struct
DeviceGemmSplitKXdl
:
public
DeviceGemm
{
{
...
@@ -63,7 +66,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -63,7 +66,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static
auto
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
,
int
KBatch
,
int
KPad
)
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
,
int
KBatch
,
int
KPad
)
{
{
assert
(
K
%
K1
==
0
);
assert
(
K
Pad
%
(
K1
*
KBatch
)
==
0
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
...
@@ -96,7 +99,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -96,7 +99,7 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static
auto
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
,
int
KBatch
,
int
KPad
)
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
,
int
KBatch
,
int
KPad
)
{
{
assert
(
K
%
K1
==
0
);
assert
(
K
Pad
%
(
K1
*
KBatch
)
==
0
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
...
@@ -141,8 +144,6 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -141,8 +144,6 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
static
auto
GetKBatchAndKPad
(
index_t
M
,
index_t
N
,
index_t
K
)
static
auto
GetKBatchAndKPad
(
index_t
M
,
index_t
N
,
index_t
K
)
{
{
if
(
!
IsSplitK
)
return
std
::
make_tuple
(
1
,
K
);
const
auto
GridMN
=
M
*
N
/
(
MPerBlock
*
NPerBlock
);
const
auto
GridMN
=
M
*
N
/
(
MPerBlock
*
NPerBlock
);
const
index_t
KBatch
=
std
::
max
(
DesiredGridSize
/
GridMN
,
1
);
const
index_t
KBatch
=
std
::
max
(
DesiredGridSize
/
GridMN
,
1
);
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
...
@@ -405,18 +406,8 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -405,18 +406,8 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
has_main_k0_block_loop
)
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
{
#if CK_RUN_KERNEL_AND_TIME
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
...
@@ -429,8 +420,55 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -429,8 +420,55 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
#else
nrepeat
++
;
launch_kernel
(
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
block_2_ctile_map_
);
#endif
};
if
(
has_main_k0_block_loop
)
{
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
}
else
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
true
>
;
Run
(
kernel
);
}
}
else
{
if
(
kbatch
==
1
)
{
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
GridwiseGemm
,
GridwiseGemm
,
...
@@ -442,18 +480,22 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
...
@@ -442,18 +480,22 @@ struct DeviceGemmSplitKXdl : public DeviceGemm
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
Run
(
kernel
);
nrepeat
,
}
dim3
(
grid_size
),
else
dim3
(
BlockSize
),
{
0
,
const
auto
kernel
=
kernel_gemm_xdlops_v2r4
<
arg
.
p_a_grid_
,
GridwiseGemmAtomicAdd
,
arg
.
p_b_grid_
,
ADataType
,
// TODO: distinguish A/B datatype
arg
.
p_c_grid_
,
CDataType
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
AGridDesc_K0_M_K1
>
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
BGridDesc_K0_N_K1
>
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
remove_reference_t
<
DeviceGemmSplitKXdl
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
arg
.
block_2_ctile_map_
);
remove_reference_t
<
DeviceGemmSplitKXdl
::
Block2CTileMap
>
,
false
>
;
Run
(
kernel
);
}
}
}
return
ave_time
;
return
ave_time
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment