gaoqiong / composable_kernel

Commit f39f7d79, authored Jun 26, 2023 by Jing Zhang
Parent: 720280ea

    skip b_lds
Showing 2 changed files with 82 additions and 117 deletions.
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp (+13, -34)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp (+69, -83)
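This commit removes the LDS staging path for the B matrix in the DL GEMM kernel: the blockwise B copy into LDS (BlockwiseTensorSliceTransfer_v5r1) is replaced by a per-thread ThreadwiseTensorSliceTransfer_v2 that loads each thread's B slice straight from global memory into registers, BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 is reworked to consume B from that register buffer, and the shared-memory budget shrinks to the double-buffered A tile alone.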
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
@@ -14,9 +14,9 @@ template <index_t BlockSize,
           typename FloatB,
           typename FloatC,
           typename ABlockDesc_K0_M_K1,
-          typename BBlockDesc_K0_N_K1,
+          typename BThreadDesc_K0_N_K1,
           index_t MPerThread,
-          index_t NPerThread,
+          index_t NPerBlock,
           index_t K0PerLoop>
 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
 {
@@ -32,10 +32,12 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     static constexpr auto M  = ABlockDesc_K0_M_K1{}.GetLength(I1);
     static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);
 
-    static constexpr auto N = BBlockDesc_K0_N_K1{}.GetLength(I1);
+    static constexpr auto NPerThread = BThreadDesc_K0_N_K1{}.GetLength(I1);
 
     static constexpr auto M0 = M / MPerThread;
     static constexpr auto M1 = MPerThread;
 
+    static constexpr auto N = NPerBlock;
+
     static constexpr auto N0 = N / NPerThread;
     static constexpr auto N1 = NPerThread;
@@ -51,15 +53,14 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
     __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
         : c_thread_origin_data_idx_{
              CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(get_thread_local_1d_id())},
-          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)},
-          b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I2] * NPerThread, 0)}
+          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * MPerThread, 0)}
     {
         static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
-                          BBlockDesc_K0_N_K1::IsKnownAtCompileTime(),
+                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
 
-        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BBlockDesc_K0_N_K1{}.GetLength(I0) &&
-                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BBlockDesc_K0_N_K1{}.GetLength(I2),
+        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BThreadDesc_K0_N_K1{}.GetLength(I0) &&
+                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BThreadDesc_K0_N_K1{}.GetLength(I2),
                       "wrong! E dimension not consistent\n");
 
         static_assert(K0 % K0PerLoop == 0, "");
@@ -90,27 +91,23 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
         return c_m0_m1_n0_n1_thread_cluster_idx;
     }
 
-    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
-                        const BBlockBuffer& b_block_buf,
+                        const BThreadBuffer& b_thread_buf,
                         CThreadBuffer& c_thread_buf) const
     {
         static_assert(
             is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
-            is_same<remove_cvref_t<typename BBlockBuffer::type>, remove_cvref_t<FloatB>>::value &&
+            is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
             is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
             "wrong! inconsistent type");
 
         constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
-        constexpr auto b_block_mtx = BBlockDesc_K0_N_K1{};
 
         // thread A buffer for GEMM
         StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
             a_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, FloatB, b_thread_mtx_.GetElementSpaceSize(), true>
-            b_thread_buf;
 
         constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                                                          FloatB,
                                                                          FloatC,
@@ -126,17 +123,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                                a_thread_mtx_,
                                make_tuple(I0, I0, I0),
                                a_thread_buf);
 
-            b_thread_copy_.Run(b_block_mtx,
-                               make_tuple(k0_begin, I0, I0),
-                               b_block_buf,
-                               b_thread_mtx_,
-                               make_tuple(I0, I0, I0),
-                               b_thread_buf);
-
             threadwise_gemm.Run(a_thread_buf,
                                 make_tuple(I0, I0, I0),
                                 b_thread_buf,
-                                make_tuple(I0, I0, I0),
+                                make_tuple(k0_begin, I0, I0),
                                 c_thread_buf,
                                 make_tuple(I0, I0, I0, I0));
         });
@@ -153,20 +143,9 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                                                          K1,
                                                          K1>;
 
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
-                                                         BBlockDesc_K0_N_K1,
-                                                         decltype(b_thread_mtx_),
-                                                         Sequence<K0PerLoop, NPerThread, K1>,
-                                                         Sequence<0, 1, 2>,
-                                                         2,
-                                                         K1,
-                                                         K1>;
-
    CIndex c_thread_origin_data_idx_;
 
    AThreadCopy a_thread_copy_;
-   BThreadCopy b_thread_copy_;
 };
 
 } // namespace ck
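The net effect on the blockwise GEMM is visible in the Run signature: B now arrives as a per-thread register buffer and is read at the k0_begin offset directly, with no intermediate copy out of an LDS tile. A minimal standalone sketch of that contract (run_gemm, the flat array layouts, and the loop structure are illustrative assumptions, not CK's ThreadwiseGemmDlops API):

#include <array>
#include <cstddef>
#include <cstdio>

// Sketch of the revised Run() contract: B is a per-thread register buffer
// (modeled as std::array) indexed at k0_begin directly, instead of being
// copied out of an LDS tile first. Layouts are flat row-major stand-ins.
template <int K0PerLoop, int MPerThread, int NPerThread,
          typename T, std::size_t AN, std::size_t BN>
void run_gemm(const std::array<T, AN>& a_thread_buf,
              const std::array<T, BN>& b_thread_buf,
              std::array<T, MPerThread * NPerThread>& c_thread_buf,
              int k0_begin)
{
    for(int k = 0; k < K0PerLoop; ++k)
        for(int m = 0; m < MPerThread; ++m)
            for(int n = 0; n < NPerThread; ++n)
                // b is read at row (k0_begin + k), mirroring the new
                // make_tuple(k0_begin, I0, I0) offset passed to threadwise_gemm
                c_thread_buf[m * NPerThread + n] +=
                    a_thread_buf[k * MPerThread + m] *
                    b_thread_buf[(k0_begin + k) * NPerThread + n];
}

int main()
{
    std::array<float, 4> a{1, 2, 3, 4}; // K0PerLoop=2 x MPerThread=2
    std::array<float, 4> b{1, 1, 1, 1}; // K0=2 rows x NPerThread=2
    std::array<float, 4> c{};
    run_gemm<2, 2, 2>(a, b, c, /*k0_begin=*/0);
    std::printf("c[0] = %g\n", c[0]); // 1*1 + 3*1 = 4
}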
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
@@ -120,20 +120,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
 
-        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_block_desc_k_n.GetElementSpaceSize(), max_lds_align);
-
-        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
+        return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
     }
 
     __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
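Since only the A tile remains in LDS, the shared-memory requirement is just the double-buffered, alignment-padded A footprint. A minimal sketch of that arithmetic (the rounding helper mimics math::integer_least_multiple; tile sizes are made up for illustration):

#include <cstddef>
#include <cstdio>

// Round x up to a multiple of align; mimics math::integer_least_multiple.
constexpr std::size_t integer_least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

// New LDS budget: double buffer for the A tile only; B no longer contributes.
constexpr std::size_t shared_memory_bytes(std::size_t a_tile_elems,
                                          std::size_t max_lds_align,
                                          std::size_t elem_bytes)
{
    return 2 * integer_least_multiple(a_tile_elems, max_lds_align) * elem_bytes;
}

int main()
{
    // Hypothetical K0PerBlock=8, MPerBlock=128, K1=4, 2-byte elements:
    std::printf("%zu bytes\n", shared_memory_bytes(8 * 128 * 4, 8, 2)); // 16384
}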
@@ -397,6 +389,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
         const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
+        ignore = b_global_buf;
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());
@@ -425,26 +418,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, dst of blockwise copy
-        // be careful of LDS alignment
-        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);
-
         // TODO: check alignment
         // A matrix in LDS memory, for blockwise GEMM
         constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
             make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
 
-        // TODO: check alignment
-        // B matrix in LDS memory, for blockwise GEMM
-        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
-            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-
-        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
-                          a_k0_m_k1_block_desc.GetElementSpaceSize() &&
-                      b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
-                          b_k0_n_k1_block_desc.GetElementSpaceSize() &&
-                      "wrong!");
 
         // A matrix blockwise copy
@@ -471,45 +451,36 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 a_block_desc_k0_m0_m1_k1,
                 make_multi_index(0, 0, 0, 0));
 
-        // B matrix blockwise copy
-        auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v5r1<BlockSize,
-                                              InMemoryDataOperationEnum::Set,
-                                              Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
-                                              BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
-                                              BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
-                                              BBlockTransferThreadClusterArrangeOrder,
-                                              FloatAB,
-                                              FloatAB,
-                                              remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
-                                              decltype(b_block_desc_k0_n0_n1_k1),
-                                              BBlockTransferSrcAccessOrder,
-                                              Sequence<0, 1, 2, 3>,
-                                              BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
-                                              BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
-                                              BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
-                                              Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder
-                                              false,
-                                              true>(b_grid_desc_k0_n0_n1_k1,
-                                                    make_multi_index(0, in0, 0, 0),
-                                                    b_block_desc_k0_n0_n1_k1,
-                                                    make_multi_index(0, 0, 0, 0));
+        static constexpr auto b_thread_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
+
+        auto b_threadwise_copy =
+            ThreadwiseTensorSliceTransfer_v2<FloatAB,
+                                             FloatAB,
+                                             remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
+                                             decltype(b_thread_desc_k0_n0_n1_k1),
+                                             Sequence<K0PerBlock, 1, NPerThread, K1.value>,
+                                             Sequence<0, 1, 2, 3>, // BBlockTransferSrcAccessOrder,
+                                             3,
+                                             K1,
+                                             1,
+                                             false,
+                                             true>(
+                b_grid_desc_k0_n0_n1_k1,
+                make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0));
+
+        static constexpr auto b_k0_n_k1_thread_desc = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<K0PerBlock>{}, Number<NPerThread>{}, Number<K1>{}));
 
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
         // a_mtx[K0PerBlock, MPerBlock] is in LDS
         // b_mtx[KPerBlock, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
         const auto blockwise_gemm =
             BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                  FloatAB,
                                                  FloatAB,
                                                  FloatAcc,
                                                  decltype(a_k0_m_k1_block_desc),
-                                                 decltype(b_k0_n_k1_block_desc),
+                                                 decltype(b_k0_n_k1_thread_desc),
                                                  MPerThread,
-                                                 NPerThread,
+                                                 NPerBlock,
                                                  KPerThread>{};
 
         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
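The important detail in the new copy is its source origin, make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0): each thread opens its own window into B, offset by its thread id along N, so the block's threads jointly cover the B slice without shared staging. A small sketch under assumed values (thread_id stands in for get_thread_local_1d_id(); NumThreads is hypothetical):

#include <cstdio>

// Each thread's copy window into B starts at thread_id * NPerThread along N,
// mirroring make_multi_index(0, in0, get_thread_local_1d_id() * NPerThread, 0).
int main()
{
    constexpr int NPerThread = 4;
    constexpr int NumThreads = 8; // hypothetical block size
    for(int thread_id = 0; thread_id < NumThreads; ++thread_id)
    {
        const int n_origin = thread_id * NPerThread;
        std::printf("thread %d reads B columns [%d, %d)\n",
                    thread_id, n_origin, n_origin + NPerThread);
    }
}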
@@ -522,11 +493,13 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
 
-        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);
-
         FloatAB* p_a_block_double = p_shared_block;
-        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;
+
+        auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());
+        auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_k0_n_k1_thread_desc.GetElementSpaceSize());
 
         // register allocation for output
         auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
@@ -536,27 +509,25 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         c_thread_buf.Clear();
 
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
+        constexpr auto b_thread_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
 
         auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
         auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             p_a_block_double + a_block_aligned_space_size,
             a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
-        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block_double + b_block_aligned_space_size,
-            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());
 
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
 
             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
+
+            b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                  b_global_buf,
+                                  b_thread_desc_k0_n0_n1_k1,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_even_buf);
         }
 
         if constexpr(HasMainKBlockLoop)
@@ -572,40 +543,50 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                 // even iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);
 
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_odd_buf);
 
                 block_sync_lds();
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
                 // odd iteration
                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                     a_block_slice_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                    b_block_slice_copy_step);
+                b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                     b_thread_slice_copy_step);
 
                 // LDS double buffer: load next data from device mem
                 a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+                b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                      b_global_buf,
+                                      b_thread_desc_k0_n0_n1_k1,
+                                      make_tuple(I0, I0, I0, I0),
+                                      b_thread_even_buf);
 
                 block_sync_lds();
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
-                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
 
                 k_block_data_begin += 2 * K0PerBlock;
             } while(k_block_data_begin < K0 - 2 * K0PerBlock);
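The main loop keeps its even/odd double-buffer schedule, but the B prefetch now ping-pongs between two register buffers instead of two LDS tiles, so only A still needs block_sync_lds(). A compact host-side model of that pipeline (toy sizes, a running sum standing in for the GEMM, and all names illustrative):

#include <array>
#include <cstdio>
#include <vector>

// Toy model of the pipeline: each B tile is fetched straight into a
// per-thread register buffer (a std::array here); even/odd buffers
// alternate exactly like b_thread_even_buf / b_thread_odd_buf.
int main()
{
    constexpr int K0 = 8, K0PerBlock = 2, N = 4;
    std::vector<float> b_global(K0 * N, 1.0f); // flat K0 x N, all ones

    std::array<float, K0PerBlock * N> b_thread_even{}, b_thread_odd{};

    auto load_b = [&](int k0_begin, auto& dst) {
        for(int i = 0; i < K0PerBlock * N; ++i)
            dst[i] = b_global[k0_begin * N + i]; // global -> registers, no LDS hop
    };

    load_b(0, b_thread_even); // preload, as before the main K loop

    float acc = 0.0f;
    for(int k0 = 0; k0 < K0; k0 += 2 * K0PerBlock)
    {
        load_b(k0 + K0PerBlock, b_thread_odd); // prefetch next tile (even iteration)
        for(float v : b_thread_even) acc += v; // consume current tile

        if(k0 + 2 * K0PerBlock < K0)
            load_b(k0 + 2 * K0PerBlock, b_thread_even); // prefetch (odd iteration)
        for(float v : b_thread_odd) acc += v;           // consume odd tile
    }
    std::printf("acc = %g\n", acc); // 32: every element visited exactly once
}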
@@ -615,32 +596,37 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
         {
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                 a_block_slice_copy_step);
-            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
-                                                b_block_slice_copy_step);
+            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
+                                                 b_thread_slice_copy_step);
 
             block_sync_lds();
 
             // LDS double buffer: load last data from device mem
             a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
-            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);
+
+            b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
+                                  b_global_buf,
+                                  b_thread_desc_k0_n0_n1_k1,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_odd_buf);
 
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
 
             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
-            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);
 
             block_sync_lds();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+            blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
         }
 
         // output: register to global memory