gaoqiong / composable_kernel

Commit 888f1d68
Authored Apr 21, 2021 by Chao Liu
Parent 35d68cf8

replace raw pointer with DynamicBuffer in blockwise and threadwise gemm

Showing 5 changed files with 62 additions and 42 deletions:

  composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                         +21 -20
  composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                     +17 -13
  composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp  +15  -8
  composable_kernel/include/utility/buffer.hpp                                              +8  -0
  composable_kernel/include/utility/config.amd.hpp.in                                       +1  -1
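The pattern throughout the commit is mechanical: everywhere a kernel previously passed raw `FloatA*` / `FloatB*` pointers, it now passes a DynamicBuffer built by `make_dynamic_buffer`, and the receiving functions take the buffer type as a template parameter. For orientation, a minimal sketch of the wrapper pattern, using reduced stand-ins for the real `DynamicBuffer` / `make_dynamic_buffer` utilities in `buffer.hpp` (the real ones also carry typed views and `__host__ __device__` annotations):

#include <cstddef>

// Reduced stand-in for composable_kernel's DynamicBuffer: a thin wrapper
// over a runtime pointer, so callers pass "a buffer" rather than a raw T*.
template <typename T>
struct DynamicBufferSketch
{
    T* p_data_;

    T& operator[](std::size_t i) const { return p_data_[i]; }
};

// Factory relying on template argument deduction; this is also why the commit
// can drop the explicit element type from make_dynamic_buffer<FloatA>(...).
template <typename T>
DynamicBufferSketch<T> make_dynamic_buffer_sketch(T* p)
{
    return DynamicBufferSketch<T>{p};
}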
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -503,10 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                 level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }
 
-    template <typename CThreadBuffer>
-    __device__ void Run_pipelined_2x2(const FloatA* p_a_block,
-                                      const FloatB* p_b_block,
-                                      CThreadBuffer c_thread_buf) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
+                                      const BBlockBuffer& b_block_buf,
+                                      CThreadBuffer& c_thread_buf) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -548,8 +548,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         FloatA p_a_thread[a_thread_mtx_desc_.GetElementSpaceSize()];
         FloatB p_b_thread[b_thread_mtx_desc_.GetElementSpaceSize()];
 
-        auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
-        auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
+        auto a_thread_buf = make_dynamic_buffer(p_a_thread);
+        auto b_thread_buf = make_dynamic_buffer(p_b_thread);
 
         constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
                                                                       FloatB,
@@ -561,7 +561,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_a_block,
+                           a_block_buf,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
                            a_thread_buf);
@@ -569,7 +569,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read B_sub_0
         b_thread_copy_.Run(BlockMatrixB{},
                            make_tuple(Number<0>{}, Number<0>{}),
-                           p_b_block,
+                           b_block_buf,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<0>{}),
                            b_thread_buf);
@@ -577,7 +577,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read B_sub_1
         b_thread_copy_.Run(BlockMatrixB{},
                            make_tuple(Number<0>{}, Number<NPerLevel1Cluster>{}),
-                           p_b_block,
+                           b_block_buf,
                            b_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
                            b_thread_buf);
@@ -585,7 +585,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // read A_sub_1
         a_thread_copy_.Run(BlockMatrixA{},
                            make_tuple(Number<0>{}, Number<MPerLevel1Cluster>{}),
-                           p_a_block,
+                           a_block_buf,
                            a_thread_mtx_desc_,
                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
                            a_thread_buf);
@@ -611,7 +611,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read A_sub_0
             a_thread_copy_.Run(BlockMatrixA{},
                                make_tuple(k, Number<0>{}),
-                               p_a_block,
+                               a_block_buf,
                                a_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<0>{}),
                                a_thread_buf);
@@ -627,7 +627,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
                                make_tuple(k, Number<0>{}),
-                               p_b_block,
+                               b_block_buf,
                                b_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<0>{}),
                                b_thread_buf);
@@ -643,7 +643,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read B_sub_1
             b_thread_copy_.Run(BlockMatrixB{},
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
-                               p_b_block,
+                               b_block_buf,
                                b_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
                                b_thread_buf);
@@ -651,7 +651,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             // read A_sub_1
             a_thread_copy_.Run(BlockMatrixA{},
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
-                               p_a_block,
+                               a_block_buf,
                                a_thread_mtx_desc_,
                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
                                a_thread_buf);
@@ -690,9 +690,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }
 
-    template <typename CThreadBuffer>
-    __device__ void Run(const FloatA* p_a_block,
-                        const FloatB* p_b_block,
-                        CThreadBuffer c_thread_buf) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
         constexpr auto I0 = Number<0>{};
@@ -706,14 +707,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         if constexpr(MRepeat == 2 && NRepeat == 2)
         {
-            Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
+            Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
         }
         else
         {
-            Run_naive(p_a_block, p_b_block, c_thread_buf);
+            Run_naive(a_block_buf, b_block_buf, c_thread_buf);
         }
 #else
-        Run_naive(p_a_block, p_b_block, c_thread_buf);
+        Run_naive(a_block_buf, b_block_buf, c_thread_buf);
 #endif
     }
 };
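With `Run` and `Run_pipelined_2x2` templated on `ABlockBuffer` / `BBlockBuffer`, the blockwise GEMM body no longer fixes its operands to raw LDS pointers; any buffer type with the expected interface will instantiate. A hedged sketch of why that helps (all names below are illustrative, not from the repo):

// Illustrative only: one templated body serves any buffer type supporting
// indexed access, so swapping DynamicBuffer for a differently-backed buffer
// (e.g. a host-side debug buffer) needs no overloads on raw pointer types.
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
void run_gemm_sketch(const ABlockBuffer& a_block_buf,
                     const BBlockBuffer& b_block_buf,
                     CThreadBuffer& c_thread_buf,
                     int k_len)
{
    // toy inner product standing in for the real threadwise GEMM
    for(int k = 0; k < k_len; ++k)
        c_thread_buf[0] += a_block_buf[k] * b_block_buf[k];
}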
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
@@ -751,6 +751,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};
 
+        FloatAB* p_a_block_even = p_a_block_double;
+        FloatAB* p_b_block_even = p_b_block_double;
+
+        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
+        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
+
+        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
+        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
+
+        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
+        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
@@ -762,12 +774,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_a_block_even = p_a_block_double;
-            FloatAB* p_b_block_even = p_b_block_double;
-
-            FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-            FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
             index_t k_block_data_begin = 0;
 
             // LDS double buffer: main body
@@ -791,7 +797,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
@@ -814,7 +820,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
 
                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
@@ -841,7 +847,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
                 // LDS double buffer: GEMM on 2nd-last data
-                blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
 
                 // LDS double buffer: store last data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
@@ -850,16 +856,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 __syncthreads();
 
                 // LDS double buffer: GEMM on last data
-                blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                                   p_b_block_double + b_block_space_size,
-                                   c_thread_buf);
+                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
             }
             else // if has 1 iteration left
             {
                 __syncthreads();
 
                 // LDS double buffer: GEMM on last data
-                blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
             }
 
         // output: register to global memory
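The gridwise hunks hoist the even/odd pointer and buffer setup out of the `HasMainKBlockLoop` branch so the same `*_buf` objects are visible to the preload, the main loop, and the tail. The schedule itself is unchanged; for orientation, a host-side sketch of the ping-pong pattern, with all names illustrative:

#include <algorithm>
#include <cstddef>
#include <vector>

// Host-side sketch of the LDS double-buffer schedule: consume one half of
// the scratch area while the next K-block is staged into the other half.
void double_buffer_schedule_sketch(const std::vector<float>& global_src,
                                   std::size_t block_size)
{
    std::vector<float> lds(2 * block_size);
    float* even = lds.data();               // "even" half
    float* odd  = lds.data() + block_size;  // "odd" half

    // preload: first K-block goes into the even half
    std::copy_n(global_src.data(), block_size, even);

    std::size_t num_k_blocks = global_src.size() / block_size;
    for(std::size_t k = 0; k + 2 < num_k_blocks; k += 2)
    {
        // stage block k+1 into the odd half, then "compute" on the even half
        std::copy_n(global_src.data() + (k + 1) * block_size, block_size, odd);
        // ... GEMM on even half (blockwise_gemm.Run(a_block_even_buf, ...)) ...

        // stage block k+2 into the even half, then compute on the odd half
        std::copy_n(global_src.data() + (k + 2) * block_size, block_size, even);
        // ... GEMM on odd half ...
    }
    // tail: one or two remaining blocks, mirroring the if/else in the hunk above
}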
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -1321,12 +1321,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
 // Assume:
 //   1. src:
-//     1. src_desc is known at compile-time
-//     2. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
+//     1. SrcDesc is known at compile-time
+//     2. SrcBuffer is DynamicBuffer
+//     3. a reference src_reference_idx is given at run-time, src_slice_origin_idx has a
 //        compile-time distance to src_reference_idx
-//     3. use #-iterator
+//     4. use #-iterator
 //   2. dst:
-//     1. dst_desc is known at compile-time
-//     2. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
+//     1. DstDesc is known at compile-time
+//     2. DstBuffer is StaticBuffer
+//     3. a reference dst_reference_idx is given at compile-time, dst_slice_origin_idx has a
 //        compile-time distance to dst_reference_idx
 //     3. use direct address calculation (lower of coordinate)
@@ -1364,10 +1367,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
     template <typename SrcRefToOriginDisplacement,
               typename DstRefToOriginDisplacement,
+              typename SrcBuffer,
               typename DstBuffer>
     __device__ void Run(const SrcDesc&,
                         const SrcRefToOriginDisplacement&,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstRefToOriginDisplacement&,
                         DstBuffer& dst_buf) const
@@ -1375,6 +1379,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
 
+#if 0 // debug
+        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
+#endif
+
         static_assert(
             is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
                 is_known_at_compile_time<
@@ -1462,8 +1470,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                     src_desc, src_data_coord);
 
                 src_tmp_buf.template AsType<src_vector_t>()(Number<0>{}) =
-                    is_src_valid
-                        ? *reinterpret_cast<const src_vector_t*>(&p_src[src_data_coord.GetOffset()])
-                        : src_vector_t{0};
+                    is_src_valid ? src_buf.template AsType<src_vector_t>()[src_data_coord.GetOffset()]
+                                 : src_vector_t{0};
 
         // copy data from src_tmp_buf to dst_tmp_buf (data cast data from SrcData to DstData)
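The load-site change above is the payoff: the per-call `reinterpret_cast<const src_vector_t*>` disappears behind the buffer's `AsType<>` view, so a load reads like plain indexed access. A reduced sketch of that idea (not the real implementation, which also routes through a `PointerWrapper` and handles validity):

#include <cstddef>

// Reduced sketch: the buffer centralizes the reinterpretation of its raw
// storage as a vector type, instead of one reinterpret_cast per load site.
template <typename T>
struct BufferViewSketch
{
    T* p_;

    template <typename X>
    X* AsType() const
    {
        // single, central reinterpretation point for vectorized access
        return reinterpret_cast<X*>(p_);
    }
};

A call site then reads `buf.template AsType<src_vector_t>()[src_data_coord.GetOffset()]`, as in the hunk above.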
composable_kernel/include/utility/buffer.hpp
@@ -15,6 +15,10 @@ struct StaticBuffer : public vector_type<ScalarType, N>
     using base = vector_type<ScalarType, N>;
 
     __host__ __device__ constexpr StaticBuffer() : base{} {}
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 };
 
 template <typename T, index_t N>
@@ -65,6 +69,10 @@ struct DynamicBuffer
     {
         return PointerWrapper<X>{reinterpret_cast<X*>(p_scalar_)};
     }
+
+    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }
+
+    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
 };
 
 template <typename T>
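The two new traits give transfer code a way to assert or branch on the storage kind at compile time; the commit itself only adds them (and wires up a still-disabled `static_assert` under `#if 0` in threadwise_dynamic_tensor_slice_transfer.hpp). A sketch of the intended use, with illustrative names:

// Sketch: compile-time dispatch on the buffer kind via the new traits.
template <typename DstBuffer>
void write_dst_sketch(DstBuffer& dst_buf)
{
    static_assert(DstBuffer::IsStaticBuffer() || DstBuffer::IsDynamicBuffer(),
                  "wrong! DstBuffer must report its storage kind");

    if constexpr(DstBuffer::IsStaticBuffer())
    {
        // register-backed path: offsets can be compile-time constants
    }
    else
    {
        // memory-backed path: offsets are computed at run time
    }
}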
composable_kernel/include/utility/config.amd.hpp.in
@@ -28,7 +28,7 @@
 #endif
 
 // launch bounds
-#define CK_USE_LAUNCH_BOUNDS 1
+#define CK_USE_LAUNCH_BOUNDS 0
 
 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
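For context, `__launch_bounds__` caps the threads-per-block the compiler must budget registers for, and flipping the macro to 0 is presumably meant to drop that cap. One caveat worth noting: the guard in the surrounding context is `#ifdef`, which remains true even when the macro is defined as 0, so only a value test actually honors the flip. A hedged sketch (not the repo's macros):

// Sketch: an #if value test distinguishes 0 from 1, whereas
// #ifdef CK_USE_LAUNCH_BOUNDS is true for both definitions.
#define CK_USE_LAUNCH_BOUNDS 0
#define CK_MAX_THREAD_PER_BLOCK 256

#if CK_USE_LAUNCH_BOUNDS
#define CK_KERNEL_ATTR __launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
#else
#define CK_KERNEL_ATTR
#endif

// usage in a HIP kernel:
// __global__ void CK_KERNEL_ATTR gridwise_gemm_kernel(...);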