gaoqiong / composable_kernel

Commit 59462dca, authored May 20, 2021 by Jing Zhang
Parent: 2cf1757e

    use StaticBuffer of vector_type
Showing 3 changed files with 17 additions and 30 deletions.
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp   +10 -14
composable_kernel/include/tensor_operation/xdlops_gemm.hpp                     +3 -12
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp   +4 -4
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
@@ -315,8 +315,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
        constexpr index_t BlkSize = OutputLayout.GetBlkSize();
        constexpr index_t NumBlks = OutputLayout.GetNumBlks();

        constexpr auto c_mr_nr_nb_bk_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{}));

        // constexpr auto c_mr_nr_nb_bk_thread_desc =
        //     make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(Number<MRepeat>{},
        //     Number<NRepeat>{}, Number<NumBlks>{}, Number<BlkSize>{}));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size =
@@ -337,9 +338,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
        // Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
        //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

-       constexpr auto c_vec_size = c_mr_nr_nb_bk_thread_desc.GetElementSpaceSize();
-       vector_type<float, c_vec_size> c_thread_buf;
+       StaticBuffer<AddressSpace::Vgpr, vector_type<float, NumBlks * BlkSize>, MRepeat * NRepeat>
+           c_thread_buf;

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
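The gist of the change so far: the per-thread C accumulator, previously one flat vector_type<float, c_vec_size> addressed through the c_mr_nr_nb_bk thread descriptor, becomes a StaticBuffer holding MRepeat * NRepeat entries of vector_type<float, NumBlks * BlkSize>. Below is a minimal, self-contained sketch of that indexing scheme; the Number, vector_type and StaticBuffer definitions are simplified stand-ins written only for illustration, not the types from composable_kernel.

    // Hypothetical, simplified stand-ins for Number, vector_type and StaticBuffer,
    // only to show the indexing shape used by the new accumulator layout.
    #include <array>
    #include <cstddef>
    #include <cstdio>

    template <std::size_t I>
    struct Number { static constexpr std::size_t value = I; };

    // A "vector_type<float, N>"-like register tile, viewable as N scalars.
    template <typename T, std::size_t N>
    struct vector_type_sketch
    {
        std::array<T, N> data{};

        template <typename S>
        std::array<S, N>& AsType() { return data; }            // simplified: S == T only

        template <typename S>
        const std::array<S, N>& AsType() const { return data; }
    };

    // A "StaticBuffer"-like compile-time indexed array of such vectors.
    template <typename T, std::size_t NumElems>
    struct StaticBuffer_sketch
    {
        std::array<T, NumElems> elems{};

        template <std::size_t I>
        T& operator()(Number<I>) { return elems[I]; }

        template <std::size_t I>
        const T& operator[](Number<I>) const { return elems[I]; }
    };

    int main()
    {
        constexpr std::size_t MRepeat = 2, NRepeat = 2, NumBlks = 4, BlkSize = 16;

        // New layout: MRepeat * NRepeat vectors, each holding NumBlks * BlkSize floats.
        StaticBuffer_sketch<vector_type_sketch<float, NumBlks * BlkSize>, MRepeat * NRepeat>
            c_thread_buf;

        // Address element (mr, nr, blk, j) the way the new kernel code does.
        constexpr std::size_t mr = 1, nr = 0, blk = 2, j = 5;
        c_thread_buf(Number<mr * NRepeat + nr>{}).AsType<float>()[blk * BlkSize + j] = 1.0f;

        std::printf("%f\n",
                    c_thread_buf[Number<mr * NRepeat + nr>{}].AsType<float>()[blk * BlkSize + j]);
    }

With this layout, picking the (mr, nr) repeat selects a whole register tile at compile time, and the scalar index inside it is simply blk * BlkSize + j, which is exactly how the copy loop in the next hunk is written.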
@@ -475,23 +475,19 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
        constexpr index_t M1 = OutputLayout.N1();
        constexpr index_t M2 = OutputLayout.M0();

        // static_assert(M0 == 4 && M1 == 2 && M2 == 4, "");

        constexpr auto c_m0_m1_m2_n_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));

-       // static_assert(BlkSize == 16 && NumBlks == 4, "");
+       StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_blk_buf_;

        static_for<0, MRepeat, 1>{}([&](auto mr_i) {
            static_for<0, NRepeat, 1>{}([&](auto nr_i) {
                static_for<0, NumBlks, 1>{}([&](auto blk_i) {
-                   StaticBuffer<AddressSpace::Vgpr, float, BlkSize> c_thread_buf_;

                    static_for<0, BlkSize, 1>{}([&](auto j) {
-                       c_thread_buf_(j) = c_thread_buf.template AsType<float>()[Number<
-                           c_mr_nr_nb_bk_thread_desc.CalculateOffset(
-                               make_tuple(mr_i, nr_i, blk_i, j))>{}];
+                       c_blk_buf_(j) =
+                           c_thread_buf[Number<mr_i * NRepeat + nr_i>{}]
+                               .template AsType<float>()[Number<blk_i * BlkSize + j>{}];
                    });

                    // calculate origin of thread output tensor on global memory
@@ -526,7 +522,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                            b_thread_data_on_global)}
                        .Run(c_m0_m1_m2_n_thread_desc,
                             make_tuple(I0, I0, I0, I0),
-                            c_thread_buf_,
+                            c_blk_buf_,
                             c_m0_m1_m2_n_global_desc,
                             c_global_buf,
                             c_m0_m1_m2_n_global_tensor_iterator_hacks);
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
@@ -791,24 +791,15 @@ struct XdlopsGemm
        static_assert(KPerWave % KPerXdlops == 0, "KPerWave cannot be divided by KPerXdlops");

-       vector_type<base_type, GetXdlopsInfo().GetNumCRegs()> t;
-
-       using c_type = decltype(GetXdlopsInfo().GetCType());
-
        constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0));

        static_for<0, KPerWave, KPerXdlops>{}([&](auto k) {
            constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(k, m0, 0));
            constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(k, n0, 0));

-           t.template AsType<c_type>()(Number<0>{}) =
-               p_c_thread.template AsType<c_type>()[Number<c_offset>{}];
-
-           mfma_type.template run<MPerXdlops, NPerXdlops>(
-               p_a_wave[Number<a_offset>{}], p_b_wave[Number<b_offset>{}], t);
-
-           p_c_thread.template AsType<c_type>()(Number<c_offset>{}) =
-               t.template AsType<c_type>()[Number<0>{}];
+           mfma_type.template run<MPerXdlops, NPerXdlops>(p_a_wave[Number<a_offset>{}],
+                                                          p_b_wave[Number<b_offset>{}],
+                                                          p_c_thread(Number<c_offset>{}));
        });
    }
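In XdlopsGemm the same refactor removes the staging vector t: the accumulator element for (m0, n0) is handed to the mfma call directly instead of being copied in before and copied back after each K step. A tiny standalone sketch of that pattern follows, with hypothetical types (vec4, run_sketch) standing in for the real vector_type and mfma_type.

    // Standalone sketch, not repository code: accumulate in place instead of
    // staging through a temporary and copying back.
    #include <cstdio>

    struct vec4 { float x[4]; };   // hypothetical stand-in for one xdlops C register group

    // Stand-in for mfma_type.run: accumulate a*b into every lane of c.
    template <int MPerXdlops, int NPerXdlops>
    void run_sketch(float a, float b, vec4& c)
    {
        for(float& lane : c.x)
            lane += a * b;
    }

    int main()
    {
        vec4 c_thread{{0, 0, 0, 0}};

        // Old pattern (sketch):              New pattern (sketch):
        //   vec4 t;
        //   t = c_thread;                    run_sketch<64, 64>(a, b, c_thread);
        //   run_sketch<64, 64>(a, b, t);
        //   c_thread = t;
        run_sketch<64, 64>(2.0f, 3.0f, c_thread);

        std::printf("%f\n", c_thread.x[0]);   // prints 6.000000
    }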
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
@@ -106,7 +106,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
-   constexpr index_t GemmKPerBlock = 8;
+   constexpr index_t GemmKPerBlock = 16;
    constexpr index_t GemmMPerWave = 64;
    constexpr index_t GemmNPerWave = 64;
@@ -115,13 +115,13 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    constexpr index_t MRepeat = 1;
    constexpr index_t NRepeat = 1;

-   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+   using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
+   using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;

    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<1, 4>;
+   using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;

    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
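The driver change doubles GemmKPerBlock from 8 to 16 and adjusts the A and B block-transfer slice and cluster lengths to match. A standalone sanity check of that arithmetic (not part of the driver; the 256-thread figure is inferred from the cluster products, 2 x 128 and 8 x 32 before, 4 x 64 and 8 x 32 after):

    // Standalone check that the new slice/cluster lengths still tile the block exactly.
    constexpr int GemmMPerBlock = 128;
    constexpr int GemmNPerBlock = 128;
    constexpr int GemmKPerBlock = 16;

    // A matrix: Sequence<4, 2> slice per thread, Sequence<4, 64> thread cluster.
    static_assert(4 * 4 == GemmKPerBlock && 2 * 64 == GemmMPerBlock, "A block tile covered");
    // B matrix: Sequence<2, 4> slice per thread, Sequence<8, 32> thread cluster.
    static_assert(2 * 8 == GemmKPerBlock && 4 * 32 == GemmNPerBlock, "B block tile covered");
    // Both clusters multiply out to the same thread count as before.
    static_assert(4 * 64 == 256 && 8 * 32 == 256, "thread count unchanged");

    int main() { return 0; }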