Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
474733b5
Commit
474733b5
authored
Apr 23, 2021
by
Chao Liu
Browse files
updating v5r1
parent
415a4a5b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
81 additions
and
36 deletions
+81
-36
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
+8
-0
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
+15
-10
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
...nel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
+32
-20
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+10
-6
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
...le_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
+8
-0
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
...le_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
+8
-0
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
View file @
474733b5
...
...
@@ -518,6 +518,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
View file @
474733b5
...
...
@@ -228,8 +228,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
index_t
w
;
};
MatrixIndex
c_thread_begin_mtx_idx_
;
// HACK: fix this @Jing Zhang
static
constexpr
index_t
KPerThreadSubC
=
4
;
...
...
@@ -255,6 +253,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
AddressSpace
::
Vgpr
,
1
>
;
MatrixIndex
c_thread_begin_mtx_idx_
;
AThreadCopy
a_thread_copy_
;
__device__
BlockwiseGemm_km_kn_m0m1n0n1_v3
()
...
...
@@ -313,9 +313,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
return
MatrixIndex
{
k_thread_id
,
h_thread_id
,
w_thread_id
};
}
__device__
void
Run
(
const
FloatA
*
p_a_block
,
const
FloatB
*
p_b_thread
,
FloatC
*
p_c_thread
)
const
template
<
typename
ABlockBuffer
,
typename
BThreadBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_block_buf
=
make_dynamic_buffer
(
p_a_block
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABlockBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CThreadBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -334,12 +343,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
static_assert
(
HPerThread
%
HoPerThreadSubC
==
0
,
""
);
static_assert
(
WPerThread
%
WoPerThreadSubC
==
0
,
""
);
// thread A, B for GEMM
FloatA
p_a_thread
[
a_thread_mtx_
.
GetElementSpaceSize
()];
auto
a_thread_buf
=
make_dynamic_buffer
(
p_a_thread
);
auto
b_thread_buf
=
make_dynamic_buffer
(
p_b_thread
);
auto
c_thread_buf
=
make_dynamic_buffer
(
p_c_thread
);
// thread A buffer for GEMM
StaticBuffer
<
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
()
>
a_thread_buf
;
constexpr
auto
threadwise_gemm
=
ThreadwiseGemm_km_kn_mn_v3
<
FloatA
,
FloatB
,
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
View file @
474733b5
...
...
@@ -685,7 +685,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// zero out threadwise output
threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
#el
se
#el
if
0
// register allocation for output
FloatAcc
p_c_thread
[
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
()];
...
...
@@ -695,7 +695,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
decltype
(
c_k_n_ho_wo_thread_desc
),
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>>
{}
.
Run
(
c_k_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
FloatAcc
{
0
});
#elif 1
// register allocation for output
StaticBuffer
<
FloatAcc
,
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
()
>
c_thread_buf
;
ThreadwiseDynamicTensorSliceSet_v1
<
FloatAcc
,
decltype
(
c_k_n_ho_wo_thread_desc
),
Sequence
<
KPerThread
,
1
,
HoPerThread
,
WoPerThread
>>
{}
.
Run
(
c_k_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
FloatAcc
{
0
});
#endif
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
...
...
@@ -735,7 +742,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
.
p_data_
,
b_thread_even_buf
,
b_e_n_ho_wo_global_iterator_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e_k_desc
,
p_a_block
);
...
...
@@ -759,14 +766,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
.
p_data_
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
)),
b_thread_even_buf
.
p_data_
,
p_c_thread
);
make_dynamic_buffer
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
))),
b_thread_even_buf
,
c_thread_buf
);
b_block_data_begin
+=
EPerBlock
;
...
...
@@ -777,14 +785,15 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
.
p_data_
,
b_thread_even_buf
,
b_e_n_ho_wo_global_iterator_hacks
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
)),
b_thread_odd_buf
.
p_data_
,
p_c_thread
);
make_dynamic_buffer
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
))),
b_thread_odd_buf
,
c_thread_buf
);
b_block_data_begin
+=
EPerBlock
;
...
...
@@ -801,30 +810,33 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
p_b_global
,
b_e_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
.
p_data_
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_iterator_hacks
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
)),
b_thread_even_buf
.
p_data_
,
p_c_thread
);
make_dynamic_buffer
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
))),
b_thread_even_buf
,
c_thread_buf
);
b_block_data_begin
+=
EPerBlock
;
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
)),
b_thread_odd_buf
.
p_data_
,
p_c_thread
);
make_dynamic_buffer
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
))),
b_thread_odd_buf
,
c_thread_buf
);
}
else
// if has 1 iteration left
{
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
)),
b_thread_even_buf
.
p_data_
,
p_c_thread
);
make_dynamic_buffer
(
p_a_block
+
a_e_k_block_desc
.
CalculateOffset
(
make_tuple
(
b_block_data_begin
,
0
))),
b_thread_even_buf
,
c_thread_buf
);
}
// output: register to global memory
...
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
474733b5
...
...
@@ -422,12 +422,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
src_slice_origin_coord_
=
make_dynamic_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
template
<
typename
DstSliceOriginIdx
,
typename
SrcIteratorHacks
>
template
<
typename
DstBuffer
,
typename
DstSliceOriginIdx
,
typename
SrcIteratorHacks
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
Dst
Data
*
p_dst
,
Dst
Buffer
&
dst_buf
,
const
SrcIteratorHacks
&
src_iterator_hacks
)
{
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
...
...
@@ -437,6 +437,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstSliceOriginIdx
>>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
&&
"wrong! inconsistent type"
);
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cv_t
<
remove_reference_t
<
DstDesc
>>
{};
constexpr
auto
dst_slice_origin_idx
=
DstSliceOriginIdx
{};
...
...
@@ -564,7 +568,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc
.
CalculateOffset
(
to_multi_index
(
dst_slice_origin_idx
)
+
src_data_idx
+
i
*
src_scalar_step_in_vector
);
p_
dst
[
Number
<
dst_offset
>
{}
]
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
dst
_buf
(
Number
<
dst_offset
>
{}
)
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
});
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
@@ -613,12 +617,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
}
}
template
<
typename
DstSliceOriginIdx
>
template
<
typename
DstBuffer
,
typename
DstSliceOriginIdx
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
Dst
Data
*
p_dst
)
Dst
Buffer
&
dst_buf
)
{
constexpr
index_t
ntransform_src
=
SrcDesc
::
GetNumOfTransform
();
...
...
@@ -628,7 +632,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
Run
(
src_desc
,
p_src
,
DstDesc
{},
DstSliceOriginIdx
{},
p_
dst
,
src_iterator_hacks
);
Run
(
src_desc
,
p_src
,
DstDesc
{},
DstSliceOriginIdx
{},
dst
_buf
,
src_iterator_hacks
);
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
View file @
474733b5
...
...
@@ -209,6 +209,14 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
View file @
474733b5
...
...
@@ -180,6 +180,14 @@ struct ThreadwiseGemm_km_kn_mn_v3
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment