gaoqiong / composable_kernel · Commits

Commit 580e9484
authored Jul 17, 2022 by wangshaojie6

add skip lds pipeline

parent dbc971be
Showing 4 changed files with 609 additions and 166 deletions (+609 -166)

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp          +318  -0
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_c_shuffle.hpp     +1  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp            +145  -166
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_skip_lds.hpp            +145  -0
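The new files implement a pipeline that keeps the B operand of the first GEMM in per-thread registers (VGPRs) and skips staging B through LDS, while A is still copied global -> LDS -> registers. A minimal host-side C++ sketch of that pipelining pattern, using hypothetical TileA/TileB types and a MultiK0 register-buffer count rather than the kernel's real CK types:

// Minimal host-side sketch of the skip-B-LDS pipelining pattern (illustrative only).
#include <array>
#include <cstddef>

struct TileA { std::array<float, 64> data{}; }; // stand-in for an A tile staged via LDS
struct TileB { std::array<float, 16> data{}; }; // stand-in for a B tile kept in registers

template <std::size_t MultiK0>
void skip_b_lds_pipeline(std::size_t num_loop)
{
    TileA a_lds_tile{};                       // A: global -> LDS -> per-thread registers
    std::array<TileB, MultiK0> b_reg_tiles{}; // B: global -> per-thread registers, no LDS

    // preload: read the first A tile and fill all MultiK0 B register buffers
    for(std::size_t i = 0; i + 1 < num_loop; ++i)
    {
        // issue the next A global read, then walk the register-resident B tiles:
        // run the blockwise GEMM on slot k and immediately refill that slot
        for(std::size_t k = 0; k < MultiK0; ++k)
        {
            b_reg_tiles[k].data.fill(0.0f); // placeholder for "consume + refill slot k"
        }
        a_lds_tile.data.fill(0.0f);         // placeholder for "write prefetched A to LDS"
    }
    // tail: consume the B tiles that are already in registers, no more global reads
}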
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp  (new file, 0 → 100644)
#pragma once
#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0K0BN0N1N2N3K1BlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t WaveSize = 64;

    static constexpr index_t KPerBlock = K0PerBlock * KPack;

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

    static constexpr index_t KPerThread  = KPerBlock / xdlops_gemm.K0PerXdlops;
    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __device__ void MoveABlockSliceWindow()
    {
        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
    }

    __device__ void ResetABlockStartWindow()
    {
        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_thread_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                static_for<0, KPerThread, KPack>{}([&](auto k) {
                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    constexpr index_t k0 = k / KPack;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }

    private:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
                                                       Number<NRepeat>{},     // repeat
                                                       Number<KPack>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
};
} // namespace ck
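Run() above walks MRepeat x NRepeat xdlops tiles and, for each, packs KPack elements of A and B into vectors before issuing the MFMA. A standalone C++ sketch of that loop nest, with plain arrays and a scalar multiply-add standing in for the CK buffers and xdlops_gemm.Run, and with small example values for the tile constants:

// Standalone sketch of the Run() loop structure (illustrative constants, not CK's).
#include <array>

constexpr int MRepeat = 2, NRepeat = 2, KPerThread = 8, KPack = 4;

void run_sketch(const std::array<float, MRepeat * KPerThread>& a_thread_buf,
                const std::array<float, NRepeat * KPerThread>& b_thread_buf,
                std::array<float, MRepeat * NRepeat>& c_thread_buf)
{
    for(int m0 = 0; m0 < MRepeat; ++m0)              // one A sub-tile per m0
        for(int n0 = 0; n0 < NRepeat; ++n0)          // paired with each B sub-tile
            for(int k = 0; k < KPerThread; k += KPack)
                for(int i = 0; i < KPack; ++i)       // pack KPack elements, then "mfma"
                    c_thread_buf[m0 * NRepeat + n0] +=
                        a_thread_buf[m0 * KPerThread + k + i] *
                        b_thread_buf[n0 * KPerThread + k + i];
}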
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_c_shuffle.hpp
@@ -246,6 +246,7 @@ struct DeviceBatchedGemmGemmCShuffleXdl : public DeviceBatchedGemmGemmCShuffle<A
         VElementwiseOperation,
         PElementwiseOperation,
         OElementwiseOperation,
+        NumPrefetch,
         QKMPerBlock,
         QKNPerBlock,
         QKMPerXDL,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp
@@ -26,6 +26,7 @@ template <index_t BlockSize,
           typename C0ElementwiseOperation,
           typename B1ElementwiseOperation,
           typename C1ElementwiseOperation,
+          index_t NumGemmKPrefetchStage,
           index_t M0PerBlock,
           index_t N0PerBlock,
           index_t M0PerXDL,
@@ -82,6 +83,9 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     // K1 should be Number<...>
     static constexpr auto K1 = Number<AK1>{};
+    // gemm1 K1
+    static constexpr auto AccK1 = I4;
+
     static constexpr index_t WaveSize = 64;
     static constexpr index_t M0Waves  = M0PerBlock / (M0XdlPerWave * M0PerXDL);
     static constexpr index_t N0Waves  = N0PerBlock / (N0XdlPerWave * N0PerXDL);
@@ -97,6 +101,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+    using GridwiseGemmPipe = GridwiseGemmPipelineSkipLds;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = AK1;
@@ -353,7 +359,13 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
     using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
-        decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
+        decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
+
+    __host__ __device__ static constexpr auto
+    MakeB1GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2&)
+    {
+    }
+
     template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
@@ -392,6 +404,9 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
+        // B1 matrix in LDS memory, dst of blockwise copy
+        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_K0PerBlock_NPerBlock_K1();
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
             ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
@@ -439,7 +454,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
                                              FloatAB,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
                                              true>
-            b_thread_1st_buf, b_thread_2nd_buf, b_thread_3rd_buf, b_thread_4th_buf;
+            b_thread_buf[MultiK0];

         const auto wave_id     = GetWaveIdx();
         const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);
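The hunk above replaces four hand-named register buffers (b_thread_1st_buf ... b_thread_4th_buf) with a single MultiK0-sized array that the pipeline indexes in a loop. A standalone C++ sketch of that change, with a plain struct standing in for CK's StaticBuffer and a runtime loop standing in for static_for:

// Standalone sketch: an indexed array of register tiles instead of four named buffers.
#include <array>
#include <cstddef>

struct Tile { std::array<float, 16> regs{}; }; // stand-in for the CK StaticBuffer

template <std::size_t MultiK0>
void consume_all(std::array<Tile, MultiK0>& b_thread_buf)
{
    for(std::size_t i = 0; i < MultiK0; ++i)
    {
        // in the kernel: blockwise_gemm.Run(a_block_buf, b_thread_buf[i], c_thread_buf),
        // then refill slot i from global memory
        b_thread_buf[i].regs.fill(0.0f);
    }
}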
@@ -512,173 +527,137 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

+        // gridwise GEMM pipeline
         constexpr auto a_block_slice_copy_step  = make_multi_index(K0PerBlock * MultiK0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

-        // preload data to regiester and LDS
-        {
-            // Read
-            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_1st_buf);
-            // Move
-            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
-            // Initialize C
-            c_thread_buf.Clear();
-            // a data write to lds
-            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-            // load 2nd a matrix data
-            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), b_thread_2nd_buf);
-            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                 b_thread_slice_copy_step);
-        }
-        // main body
-        if constexpr(HasMainK0BlockLoop)
-        {
-            index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
-            index_t i = 0;
-            do
-            {
-                a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
-                blockwise_gemm.ResetABlockStartWindow();
-                block_sync_lds();
-                static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
-                    // 1st
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_3rd_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    s_nop();
-                    // 2nd
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_4th_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    s_nop();
-                    // 3rd
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_1st_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                    s_nop();
-                    // 4th
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_2nd_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                    blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
-                    blockwise_gemm.MoveABlockSliceWindow();
-                });
-                block_sync_lds();
-                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
-                // move a and b window
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
-                i += 1;
-            } while(i < (K0BlockMainLoop - 1));
-        }
-        // tail
-        {
-            block_sync_lds();
-            blockwise_gemm.ResetABlockStartWindow();
-            static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
-                // 1st
-                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                      b_thread_3rd_buf);
-                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                     b_thread_slice_copy_step);
-                blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
-                blockwise_gemm.MoveABlockSliceWindow();
-                // 2nd
-                b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                      b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                      b_thread_4th_buf);
-                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                     b_thread_slice_copy_step);
-                blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
-                blockwise_gemm.MoveABlockSliceWindow();
-                // 3rd
-                if constexpr(i < MultiK0 - BaseMultK0)
-                {
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_1st_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                }
-                blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
-                blockwise_gemm.MoveABlockSliceWindow();
-                // 4th
-                if constexpr(i < MultiK0 - BaseMultK0)
-                {
-                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, b_grid_buf,
-                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                          b_thread_2nd_buf);
-                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
-                                                         b_thread_slice_copy_step);
-                }
-                blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
-                blockwise_gemm.MoveABlockSliceWindow();
-            });
-        }
+        // gridwise GEMM pipeline
+        {
+            static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
+            const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
+
+            const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
+                (a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) /
+                KPerBlock);
+
+            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                                   a_block_desc_k0_m_k1,
+                                                                   a_blockwise_copy,
+                                                                   a_grid_buf,
+                                                                   a_block_buf,
+                                                                   a_block_slice_copy_step,
+                                                                   b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                                   b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                                                                   b_threadwise_copy,
+                                                                   b_grid_buf,
+                                                                   b_thread_buf,
+                                                                   b_thread_slice_copy_step,
+                                                                   blockwise_gemm,
+                                                                   c_thread_buf,
+                                                                   num_k_block_main_loop);
+        }
+
+        // gemm 1 O=PV
+        // Gemm1
+        //
+        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to A data type
+        constexpr auto acc_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / AccK1, 0, 0);
+        constexpr auto b1_block_slice_copy_step  = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
+
+        constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
+            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+        // constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
+        //     blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+
+        constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
+        constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
+        constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
+        constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
+        constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
+        constexpr auto m3 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
+        constexpr auto m4 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
+        constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
+
+        // acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 to a1_thread_desc_k0_m_k1
+        // m0_m1_m2_m3 -> k0
+        // n0_n1_n2 -> m
+        // m4 -> k1
+        // typical case: m0 = MRepeat, n0 = NRepeat, m4 = 4, the others are all 1
+        constexpr auto a1_thread_desc_k0_m_k1 = transform_tensor_descriptor(
+            acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+            make_tuple(make_merge_transform(make_tuple(m0, m1, m2, m3)),
+                       make_merge_transform(make_tuple(n0, n1, n2)),
+                       make_pass_through_transform(m4)),
+            make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
+
+        // A1 matrix blockwise copy
+        // actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
+        // TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_v3r1<
+            Sequence<m0 * m1 * m2 * m3, n0 * n1 * n2, m4>{}, // ThreadSliceLengths
+            tensor_operation::element_wise::PassThrough,     // SrcElementwiseOperation
+            tensor_operation::element_wise::PassThrough,     // DstElementwiseOperation
+            InMemoryDataOperationEnum::Set,                  // DstInMemOp
+            FloatGemmAcc,                                    // SrcData
+            FloatAB,                                         // DstData
+            a1_thread_desc_k0_m_k1,                          // SrcDesc
+            a1_thread_desc_k0_m_k1,                          // DstDesc
+            Sequence<1, 0, 2>,                               // SrcDimAccessOrder
+            Sequence<1, 0, 2>,                               // DstDimAccessOrder
+            2,                                               // SrcVectorDim
+            2,                                               // DstVectorDim
+            m4,                                              // SrcScalarPerVector
+            m4,                                              // DstScalarPerVector
+            1,                                               // SrcScalarStrideInVector
+            1,                                               // DstScalarStrideInVector
+            false,                                           // ThreadTransferSrcResetCoordinateAfterRun
+            true,                                            // ThreadTransferDstResetCoordinateAfterRun
+            NumGemmKPrefetchStage>(a1_thread_desc_k0_m_k1,
+                                   make_multi_index(0, 0, 0),
+                                   tensor_operation::element_wise::PassThrough{},
+                                   a1_thread_desc_k0_m_k1,
+                                   make_multi_index(0, 0, 0),
+                                   tensor_operation::element_wise::PassThrough{});
+
+        // B1 matrix blockwise copy
+        auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
+            ThisThreadBlock,
+            BElementwiseOperation,
+            tensor_operation::element_wise::PassThrough,
+            InMemoryDataOperationEnum::Set,
+            Sequence<B1K0, Gemm1NPerBlock, B1K1>,
+            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+            B1BlockTransferThreadClusterArrangeOrder,
+            FloatAB,
+            FloatAB,
+            decltype(b1_grid_desc_bk0_n_bk1),
+            decltype(b1_block_desc_bk0_n_bk1),
+            B1BlockTransferSrcAccessOrder,
+            Sequence<1, 0, 2>,
+            B1BlockTransferSrcVectorDim,
+            2,
+            B1BlockTransferSrcScalarPerVector,
+            B1BlockTransferDstScalarPerVector_BK1,
+            1,
+            1,
+            B1ThreadTransferSrcResetCoordinateAfterRun,
+            true,
+            NumGemmKPrefetchStage>(b1_grid_desc_bk0_n_bk1,
+                                   make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
+                                   b_element_op,
+                                   b1_block_desc_bk0_n_bk1,
+                                   make_multi_index(0, 0, 0),
+                                   tensor_operation::element_wise::PassThrough{});
+
+        auto a1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
+
+        // reuse LDS space for gemm0's a_block_buf
+        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatAB*>(p_shared), b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
+
         // output: register to global memory
         {
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_skip_lds.hpp  (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {

__device__ void s_nop()
{
#if 1
    asm volatile("\
    s_nop 0 \n \
    " ::);
#else
    __builtin_amdgcn_sched_barrier(0);
#endif
}

struct GridwiseGemmPipelineSkipLds
{
    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
    {
        // TODO: improve applicability
        return num_loop >= 2;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 2;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BThreadDesc,
              typename BThreadTransfer,
              typename BGridBuffer,
              typename BThreadBuffer,
              typename BThreadTransferStep,
              typename BlockwiseGemm,
              typename CThreadBuffer,
              index_t MultK0>
    __device__ static void Run(const AGridDesc& a_grid_desc,
                               const ABlockDesc& a_block_desc,
                               ABlockTransfer& a_blockwise_copy,
                               const AGridBuffer& a_grid_buf,
                               ABlockBuffer& a_block_buf,
                               const ABlockTransferStep& a_block_copy_step,
                               const BGridDesc& b_grid_desc,
                               const BThreadDesc& b_thread_desc,
                               BThreadTransfer& b_threadwise_copy,
                               const BGridBuffer& b_grid_buf,
                               BThreadBuffer& b_thread_buf[MultK0],
                               const BThreadTransferStep& b_thread_copy_step,
                               const BlockwiseGemm& blockwise_gemm,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
    {
        // preload data to regiester and LDS
        // Read
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

        // Move
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_slice_copy_step);

        static_for<0, MultK0, 1>{}([&](auto i_load_b) {
            b_threadwise_copy.Run(b_grid_desc,
                                  b_grid_buf,
                                  b_thread_desc,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  b_thread_buf[i_load_b]);
            s_nop();
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_thread_slice_copy_step);
        });

        // Initialize C
        c_thread_buf.Clear();

        // a data write to lds
        a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);

        // main body
        if constexpr(HasMainK0BlockLoop)
        {
            index_t i = 0;

            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                blockwise_gemm.ResetABlockStartWindow();
                block_sync_lds();

                static_for<0, MultiK0, 1>{}([&](auto i_main) {
                    blockwise_gemm.Run(a_block_buf, b_thread_buf[i_main], c_thread_buf);

                    // 1st
                    b_threadwise_copy.Run(b_grid_desc,
                                          b_grid_buf,
                                          b_thread_desc,
                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                          b_thread_buf[i_main]);
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc, b_thread_slice_copy_step);
                    blockwise_gemm.MoveABlockSliceWindow();
                    s_nop();
                });

                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);

                // move a and b window
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_slice_copy_step);

                i += 1;
            } while(i < (num_loop - 1));

            // tail
            {
                block_sync_lds();
                blockwise_gemm.ResetABlockStartWindow();

                static_for<0, MultiK0, 1>{}([&](auto i_tail) {
                    blockwise_gemm.Run(a_block_buf, b_thread_buf[i_tail], c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();
                    blockwise_gemm.Run(a_block_buf, b_thread_buf[i_tail], c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();
                });
            }
        }
    }
};

} // namespace ck
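IsSupported() and CalculateHasMainLoop() above gate the pipeline on the number of K-block iterations, which the calling gridwise kernel derives from the K extent of the A grid descriptor divided by the K covered per iteration. A self-contained sketch of that sizing check, using illustrative constants rather than values from the commit:

// Self-contained sketch of the loop-count check; constants are examples only, and
// is_supported/has_main_loop mirror IsSupported() and CalculateHasMainLoop() above.
#include <cassert>

constexpr bool is_supported(int num_loop) { return num_loop >= 2; }
constexpr bool has_main_loop(int num_loop) { return num_loop > 2; }

int main()
{
    constexpr int K_total         = 512; // total K extent of the problem (example)
    constexpr int K_per_iteration = 64;  // K consumed by one pipeline iteration (example)
    constexpr int num_loop        = K_total / K_per_iteration; // 8

    static_assert(is_supported(num_loop), "skip-LDS pipeline needs at least 2 iterations");
    return has_main_loop(num_loop) ? 0 : 1; // selects the HasMainLoop=true code path
}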