gaoqiong / composable_kernel
"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "c2714fcbfd600c2a13efbc42bab95b49b0b4fa33"
Commit a956d60e, authored Jun 12, 2022 by wangshaojie6

use tuple for skip both lds

parent 244b9ffb
Showing 5 changed files with 91 additions and 75 deletions.
example/01_gemm/gemm_xdl_skip_all_lds_fp16.cpp                                   +7 −6
include/ck/config.hpp                                                            +2 −2
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_all_lds.hpp          +2 −0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp               +1 −0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_all_lds_v1.hpp    +79 −67
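The commit message "use tuple for skip both lds" refers to the register prefetch buffers in the gridwise kernel below: the C-style arrays a_thread_buf[MultiK0] and b_thread_buf[MultiK0] become tuples built with generate_tuple, so the prefetch and main loops can be unrolled with static_for and every buffer is selected with a compile-time index. A minimal standalone C++17 sketch of that idea follows; RegisterBuffer, make_buffers, and for_each_index are illustrative stand-ins for ck::StaticBuffer, generate_tuple, and static_for, not CK's actual API.

// Standalone illustration of "array of buffers" -> "tuple of buffers" (C++17).
// RegisterBuffer, make_buffers, and for_each_index are stand-ins, not CK types.
#include <array>
#include <cstdio>
#include <tuple>
#include <utility>

template <int Size>
struct RegisterBuffer // plays the role of ck::StaticBuffer (register-resident storage)
{
    std::array<float, Size> data{};
};

// plays the role of generate_tuple: N independent buffer objects, one tuple slot each
template <int Size, std::size_t... Is>
auto make_buffers(std::index_sequence<Is...>)
{
    return std::make_tuple(((void)Is, RegisterBuffer<Size>{})...);
}

// plays the role of static_for: calls f with a compile-time index 0..N-1
template <typename F, std::size_t... Is>
void for_each_index(F&& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

int main()
{
    constexpr std::size_t MultiK0 = 4; // prefetch depth, like the new template parameter
    auto bufs = make_buffers<8>(std::make_index_sequence<MultiK0>{});

    // std::get requires a compile-time index; a runtime do/while counter (the old
    // i_pre / i_k) could not provide one, which is why the arrays became a tuple.
    for_each_index(
        [&](auto i) { std::get<decltype(i)::value>(bufs).data[0] = float(decltype(i)::value); },
        std::make_index_sequence<MultiK0>{});

    for_each_index(
        [&](auto i) {
            std::printf("buf %zu holds %.1f\n",
                        decltype(i)::value,
                        std::get<decltype(i)::value>(bufs).data[0]);
        },
        std::make_index_sequence<MultiK0>{});
    return 0;
}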
example/01_gemm/gemm_xdl_skip_all_lds_fp16.cpp
@@ -43,13 +43,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSkipAllLds
-//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
-//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
-//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
-//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
+//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| MultiK0| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BThreadTransfer| CThreadTransfer| CThreadTransfer|
+//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| SrcScalar| SrcDstVectorDim| DstScalar|
+//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| | PerVector|
+//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
-        < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
+//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 8, 7, 1>;
+//< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;
+        < F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 8, 7, 1>;

 using ADataType = ck::half_t;
 using BDataType = ck::half_t;
 using CDataType = ck::half_t;
...
include/ck/config.hpp
@@ -15,7 +15,7 @@
 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
-#define CK_MIN_BLOCK_PER_CU 2
+#define CK_MIN_BLOCK_PER_CU 1
 #endif

 // check GPU target
...
@@ -98,7 +98,7 @@
 #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0

 // experimental feature: buffer load/store/atomic-add/ OOB trick
-#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
+#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0
 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1
...
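The two config.hpp switches are tuning knobs: CK_MIN_BLOCK_PER_CU drops from 2 to 1 (with __launch_bounds__, a smaller minimum-blocks-per-CU value relaxes the per-thread register budget the compiler must respect), and the buffer-load OOB offset trick is turned off. Below is a sketch of how launch-bounds macros of this kind are typically consumed in HIP code; the exact kernel declaration in composable_kernel may differ, so treat the wrapper as illustrative only.

// Illustrative HIP kernel showing where CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU
// style macros usually end up; the kernel name and body are placeholders.
#include <hip/hip_runtime.h>

#define CK_MAX_THREAD_PER_BLOCK 256
#define CK_MIN_BLOCK_PER_CU 1

__global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
    zero_fill_kernel(float* out)
{
    out[blockIdx.x * blockDim.x + threadIdx.x] = 0.0f;
}

int main()
{
    float* d_out = nullptr;
    hipMalloc(reinterpret_cast<void**>(&d_out), CK_MAX_THREAD_PER_BLOCK * sizeof(float));
    zero_fill_kernel<<<1, CK_MAX_THREAD_PER_BLOCK>>>(d_out);
    hipDeviceSynchronize();
    hipFree(d_out);
    return 0;
}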
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_all_lds.hpp
@@ -32,6 +32,7 @@ template <typename ADataType,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
           ck::index_t K0PerBlock,
+          ck::index_t MultiK0,
           ck::index_t K1,
           ck::index_t MPerXDL,
           ck::index_t NPerXDL,
...
@@ -189,6 +190,7 @@ struct DeviceGemmXdlSkipAllLds
            MPerBlock,
            NPerBlock,
            K0PerBlock,
+           MultiK0,
            MPerXDL,
            NPerXDL,
            K1,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
@@ -3,6 +3,7 @@
 namespace ck {

+// N-stage prefetch
 template <index_t NumPrefetch>
 struct GridwiseGemmPipeline_v2;
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_all_lds_v1.hpp
@@ -46,18 +46,19 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     //__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
                                                    // p_shared,
                                                    a_grid_desc_k0_m_k1,
                                                    a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                                    b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                    c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                                    a_element_op,
                                                    b_element_op,
                                                    c_element_op,
                                                    block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
...
@@ -86,6 +87,7 @@ template <index_t BlockSize,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t K0PerBlock,
+          index_t MultiK0,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t K1Value,
...
@@ -115,7 +117,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

-    static constexpr auto MultiK0 = 4 * 1;
+    // static constexpr auto MultiK0 = 16 * 1;

     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
...
@@ -227,11 +229,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
     // TODO move this function into GEMM-pipeline class
     __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
     {
-        const bool has_main_k0_block_loop = (K0 / (MultiK0 * K0PerBlock)) > 1;
+        const bool has_main_k0_block_loop = K0 > (MultiK0 * K0PerBlock);

         return has_main_k0_block_loop;
     }

+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr index_t CalculateResMainK0BlockLoop(index_t K0)
+    {
+        const index_t res_main_k0_block_loop = (K0 / K0PerBlock) % MultiK0;
+
+        return res_main_k0_block_loop;
+    }
+
     __host__ __device__ static constexpr auto
     MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
     {
...
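The rewritten predicate and the new helper are easiest to read with concrete numbers. The values below are assumed for illustration (K0PerBlock = 4 and MultiK0 = 8 appear to match the updated example instance, while K0 = 48 is arbitrary); the snippet only restates the two formulas from the hunk above and checks one case where the old and new forms of CalculateHasMainK0BlockLoop disagree.

// Quick check of the two predicates with assumed values; illustrative only.
#include <cstdio>

constexpr int K0PerBlock = 4;
constexpr int MultiK0    = 8;

constexpr bool old_has_main_loop(int K0) { return (K0 / (MultiK0 * K0PerBlock)) > 1; }
constexpr bool new_has_main_loop(int K0) { return K0 > (MultiK0 * K0PerBlock); }
constexpr int  res_main_loop(int K0)     { return (K0 / K0PerBlock) % MultiK0; }

int main()
{
    // K0 = 48 is one case where the two forms disagree: 48 / 32 == 1 is not > 1,
    // but 48 > 32 is true, so the new form still takes the main loop.
    static_assert(!old_has_main_loop(48) && new_has_main_loop(48), "forms differ at K0 = 48");
    static_assert(res_main_loop(48) == 4, "12 K0-blocks leave a remainder of 4 modulo MultiK0");

    std::printf("K0 = 48: old = %d, new = %d, residue = %d\n",
                old_has_main_loop(48), new_has_main_loop(48), res_main_loop(48));
    return 0;
}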
@@ -396,7 +406,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
     using AGridDesc_K0_K1_K2_M0_M1_M2_M3_K3 =
         decltype(MakeAGridDescriptor_K0_K1_K2_M0_M1_M2_M3_K3(AGridDesc_K0_M_K1{}));

     template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
...
@@ -420,6 +431,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
         const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

+        //const auto ResMainK0BlockLoop = CalculateResMainK0BlockLoop(K0);
+
         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
...
@@ -444,11 +457,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
                        I1,            // NPerXdlops
                        Number<K1>{}));

-        StaticBuffer<AddressSpaceEnum::Vgpr,
-                     FloatAB,
-                     a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
-                     true>
-            a_thread_buf[MultiK0]; //, a_thread_buf_1, a_thread_buf_2, a_thread_buf_3;
+        auto a_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    FloatAB,
+                                    a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
+                                    true>{};
+            },
+            Number<MultiK0>{});
+
+        //StaticBuffer<AddressSpaceEnum::Vgpr,
+        //             FloatAB,
+        //             a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3.GetElementSpaceSize(),
+        //             true>
+        //    a_thread_buf[MultiK0];

         ignore = b_element_op;

         // B matrix threadwise copy
...
@@ -462,11 +484,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
                        I1,            // NPerXdlops
                        Number<K1>{}));

-        StaticBuffer<AddressSpaceEnum::Vgpr,
-                     FloatAB,
-                     b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
-                     true>
-            b_thread_buf[MultiK0]; //_0, b_thread_buf_1, b_thread_buf_2, b_thread_buf_3;
+        auto b_thread_buf = generate_tuple(
+            [&](auto i) {
+                ignore = i;
+                return StaticBuffer<AddressSpaceEnum::Vgpr,
+                                    FloatAB,
+                                    b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
+                                    true>{};
+            },
+            Number<MultiK0>{});
+
+        //StaticBuffer<AddressSpaceEnum::Vgpr,
+        //             FloatAB,
+        //             b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
+        //             true>
+        //    b_thread_buf[MultiK0];

         const auto wave_id = GetWaveIdx();
         const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
...
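Design note on the two buffer hunks above: a_thread_buf and b_thread_buf were plain C arrays of StaticBuffer, so the surrounding loops had to carry a runtime counter, and VGPR-resident storage generally cannot be indexed by a runtime value without the compiler unrolling the loop itself or spilling to scratch. Building the buffers with generate_tuple instead gives one tuple element per prefetch stage, and each stage is then picked with a compile-time Number<i> from inside static_for, which is the pattern the remaining hunks of this file switch to.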
@@ -564,58 +596,55 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
         // constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock * MultiK0, 0, 0);
         constexpr auto a_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

         // preload data to regiester and LDS
         if constexpr(HasMainK0BlockLoop)
         {
             // Read
-            index_t i_pre = 0;
-            do
-            {
+            static_for<0, MultiK0, 1>{}([&](auto i_pre) {
                 b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                       b_grid_buf,
                                       b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                       make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                      b_thread_buf[i_pre]);
+                                      b_thread_buf(Number<i_pre>{}));
                 a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                       a_grid_buf,
                                       a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                       make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                      a_thread_buf[i_pre]);
+                                      a_thread_buf(Number<i_pre>{}));

                 asm volatile("s_nop 0" ::);

                 // Move
                 a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                                      a_thread_slice_copy_step);
                 b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                      b_thread_slice_copy_step);
-                i_pre++;
-            } while(i_pre < MultiK0);
+            });
         }

         // Initialize C
         c_thread_buf.Clear();

         // main body
         if constexpr(HasMainK0BlockLoop)
         {
             index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
             index_t i = 0;

             do
             {
-                index_t i_k = 0;
-                do
-                {
-                    blockwise_gemm.Run(a_thread_buf[i_k], b_thread_buf[i_k], c_thread_buf);
+                static_for<0, MultiK0, 1>{}([&](auto i_k) {
+                    blockwise_gemm.Run(
+                        a_thread_buf(Number<i_k>{}), b_thread_buf(Number<i_k>{}), c_thread_buf);

                     asm volatile("s_nop 0" ::);

                     b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                           b_grid_buf,
                                           b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf[i_k]);
+                                          b_thread_buf(Number<i_k>{}));
                     a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                           a_grid_buf,
                                           a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                           make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          a_thread_buf[i_k]);
+                                          a_thread_buf(Number<i_k>{}));

                     asm volatile("s_nop 0" ::);
...
@@ -623,8 +652,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
                                                          b_thread_slice_copy_step);
                     a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
                                                          a_thread_slice_copy_step);
-                    i_k++;
-                } while(i_k < MultiK0);
+                });

                 i += MultiK0;
             } while(i < (K0BlockMainLoop - MultiK0));
...
@@ -632,26 +660,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_all_lds_v1
         // tail
         {
+            //index_t loop_num = ResMainK0BlockLoop == 0 ? MultiK0 : ResMainK0BlockLoop;
             static_for<0, MultiK0, 1>{}([&](auto i) {
-                blockwise_gemm.Run(a_thread_buf[i], b_thread_buf[i], c_thread_buf);
-
-                if constexpr(i < MultiK0 - 4)
-                {
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_buf[i]);
-                    a_threadwise_copy.Run(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
-                                          a_grid_buf,
-                                          a_thread_desc_k0_k1_k2_m0_m1_m2_m3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          a_thread_buf[i]);
-
-                    // only move b windows
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    a_threadwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_k1_k2_m0_m1_m2_m3_k3,
-                                                         a_thread_slice_copy_step);
-                }
+                blockwise_gemm.Run(a_thread_buf(Number<i>{}), b_thread_buf(Number<i>{}), c_thread_buf);
             });
         }
     }
...
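Taken together, the preload, main body, and tail above form an N-stage (here MultiK0-stage) prefetch pipeline, matching the "// N-stage prefetch" comment added to gridwise_gemm_pipeline_v2.hpp. Below is a minimal host-only sketch of just that control flow, under assumed sizes; it deliberately uses ordinary loops and a std::vector so the pipelining structure is visible without any of the CK machinery.

// Control-flow sketch of the MultiK0-stage prefetch pipeline in the kernel above.
// Names and sizes are illustrative (std::vector stands in for the tuple of register
// buffers; += stands in for blockwise_gemm.Run). It assumes the number of K blocks is a
// multiple of MultiK0 and larger than MultiK0, which is what HasMainK0BlockLoop guards.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int MultiK0      = 4;  // prefetch depth (assumed)
    constexpr int num_k_blocks = 16; // K0 / K0PerBlock (assumed)

    std::vector<int> stage(MultiK0);
    int next_block = 0;
    int acc        = 0;

    // preload: fill all MultiK0 stages before any compute
    for(int s = 0; s < MultiK0; ++s)
        stage[s] = next_block++;

    // main body: compute the current stages while refilling them with the next K blocks,
    // mirroring "do { static_for ... } while(i < K0BlockMainLoop - MultiK0)"
    int i = 0;
    do
    {
        for(int s = 0; s < MultiK0; ++s)
        {
            acc += stage[s];         // "blockwise_gemm.Run" on stage s
            stage[s] = next_block++; // "threadwise_copy.Run" refills stage s
        }
        i += MultiK0;
    } while(i < num_k_blocks - MultiK0);

    // tail: the last MultiK0 stages are already loaded, only compute remains
    for(int s = 0; s < MultiK0; ++s)
        acc += stage[s];

    std::printf("consumed blocks 0..%d, sum = %d (expected %d)\n",
                num_k_blocks - 1, acc, num_k_blocks * (num_k_blocks - 1) / 2);
    return 0;
}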