gaoqiong / composable_kernel · Commits

Commit 35267a40, authored Aug 31, 2023 by Bartlomiej Wroblewski
Parent: 9ca59788

Review: Add blockwise doc, change function names to include dimension names

Showing 2 changed files with 41 additions and 26 deletions:

include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp  +31 -15
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp             +10 -11
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
```diff
@@ -10,6 +10,15 @@
 namespace ck {
 
+/**
+ * Blockwise GEMM that uses the DPP instruction modifier to limit the amount of data loaded for
+ * each thread by sharing data between the threads of a lanegroup.
+ *
+ * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`; there are
+ * `MRepeat` iterations along the `M` dimension and `NRepeat` along `N`.
+ * In total, the algorithm runs using
+ * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves.
+ */
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
```
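The wave-count formula in the new doc comment is easiest to read with concrete numbers. Below is a minimal host-side sketch; every tile size in it (MPerBlock = 64, MRepeat = 2, MPerDpp = 16, NPerBlock = 32, NRepeat = 2, NPerDpp = 8) is assumed for illustration and does not come from this commit.

```cpp
// Worked example of the doc comment's wave-count formula; sizes are assumed.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 64, MRepeat = 2, MPerDpp = 16;
    constexpr int NPerBlock = 32, NRepeat = 2, NPerDpp = 8;

    // Waves along M and N, exactly as in the formula above.
    constexpr int MWaves = MPerBlock / (MRepeat * MPerDpp); // 64 / 32 = 2
    constexpr int NWaves = NPerBlock / (NRepeat * NPerDpp); // 32 / 16 = 2

    std::printf("waves = %d\n", MWaves * NWaves); // 2 * 2 = 4
    return 0;
}
```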
```diff
@@ -69,20 +78,24 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
     }
 
-    __device__ static auto CalculateAThreadOriginDataIndex()
+    __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
 
-        const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex();
-        return make_tuple(0, waveId_m, dpp_a_idx[I1], KPerThread * dpp_a_idx[I0]);
+        const auto dpp_a_idx   = dpp_gemm.CalculateAThreadOriginDataIndex_K_M();
+        const auto dpp_a_idx_k = dpp_a_idx[I0];
+        const auto dpp_a_idx_m = dpp_a_idx[I1];
+        return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k);
     }
 
-    __device__ static auto CalculateBThreadOriginDataIndex()
+    __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K()
     {
         const auto wave_idx = GetWaveIdx();
         const auto waveId_n = wave_idx[I1];
 
-        const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex();
-        return make_tuple(0, waveId_n, dpp_b_idx[I1], KPerThread * dpp_b_idx[I0]);
+        const auto dpp_b_idx   = dpp_gemm.CalculateBThreadOriginDataIndex_K_N();
+        const auto dpp_b_idx_k = dpp_b_idx[I0];
+        const auto dpp_b_idx_n = dpp_b_idx[I1];
+        return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k);
     }
 
     template <index_t m0, index_t n0>
```
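The point of the rename is that the returned tuple's dimension order is now part of the function name: the warp-level helpers return (K, M) and (K, N) pairs, which the blockwise code reorders into (M0, M1, M2, K) and (N0, N1, N2, K) origins. A small host-side sketch of the A-side reordering follows; `waveId_m`, `KPerThread`, and the (k, m) pair are assumed sample values.

```cpp
// Sketch of the A-origin reordering above; all input values are assumed.
#include <cstdio>
#include <tuple>

int main()
{
    constexpr int KPerThread = 4;
    const int waveId_m       = 1;

    // dpp_gemm.CalculateAThreadOriginDataIndex_K_M() returns (k, m), in that order.
    const auto [dpp_a_idx_k, dpp_a_idx_m] = std::tuple{2, 5};

    // Blockwise origin is (M0, M1, M2, K): repeat, wave, intra-wave m, k offset.
    const auto origin = std::tuple{0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k};
    std::printf("(M0, M1, M2, K) = (%d, %d, %d, %d)\n",
                std::get<0>(origin), std::get<1>(origin),
                std::get<2>(origin), std::get<3>(origin)); // (0, 1, 5, 8)
    return 0;
}
```

Unpacking into named `dpp_a_idx_k` / `dpp_a_idx_m` locals, rather than indexing `dpp_a_idx[I0]` and `[I1]` inline, is what makes the reorder visible at the call site.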
```diff
@@ -91,7 +104,10 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
         const auto wave_idx = GetWaveIdx();
         const auto waveId_m = wave_idx[I0];
         const auto waveId_n = wave_idx[I1];
 
         const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk();
+        const auto blk_m_offset = blk_idx[I0];
+        const auto blk_n_offset = blk_idx[I1];
+
         constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))),
```
```diff
@@ -104,9 +120,9 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
             make_tuple(Sequence<0, 1, 2>{}));
 
         const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex(
-            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
+            make_tuple(m0, waveId_m, blk_m_offset))[I0];
         const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex(
-            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
+            make_tuple(n0, waveId_n, blk_n_offset))[I0];
 
         return make_tuple(c_thread_m, c_thread_n);
     }
```
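Under the hood, `CalculateBottomIndex` on that single-stage unmerge adaptor flattens the (MRepeat, MWaves, MPerDpp) split of `M` back into one coordinate. A sketch of the equivalent arithmetic, assuming the usual row-major unmerge order and illustrative sizes (MWaves = 2, MPerDpp = 16):

```cpp
// Equivalent arithmetic for the unmerge adaptor above; sizes and indices are assumed.
#include <cstdio>

int main()
{
    constexpr int MWaves = 2, MPerDpp = 16;

    const int m0 = 1, waveId_m = 0, blk_m_offset = 3;

    // Row-major flattening of the (MRepeat, MWaves, MPerDpp) split of M.
    const int c_thread_m = (m0 * MWaves + waveId_m) * MPerDpp + blk_m_offset;
    std::printf("c_thread_m = %d\n", c_thread_m); // (1*2 + 0)*16 + 3 = 35
    return 0;
}
```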
```diff
@@ -324,8 +340,8 @@ struct BlockwiseGemmDpp_k0mk1_k0nk1_m0n0m1n1m2n2
                                     B_K1>;
 
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()};
+    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()};
 };
 
 } // namespace ck
```
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
```diff
@@ -20,14 +20,14 @@ enum struct DppInstr
  * Structure representing DPP GEMM executed by a single wavefront.
  *
  * Each structure instantiation must contain the following fields:
- * - wave_size - number of threads that execute single DPP GEMM operation, usually equal to the
+ * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
  *               number of threads in a wavefront;
  * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier,
  *                    it's 8 in case of DPP8;
  * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
- *   operation;
+ *                operation;
  * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
- *   operation;
+ *                operation;
  * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
  * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
  * - m_per_thread - size along M dimension of the tile calculated by a single thread;
```
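As a concrete (hypothetical) picture of such an instantiation, the sketch below fills in the listed fields for a DPP8-style configuration; the struct name and every value are invented for illustration and are not taken from dpp_gemm.hpp.

```cpp
// Hypothetical descriptor carrying the fields the doc comment lists;
// name and values are invented (a self-consistent DPP8-style setup).
#include <cstdint>

using index_t = std::int32_t;

struct dpp_instr_example
{
    static constexpr index_t wave_size       = 32; // threads per DPP GEMM operation
    static constexpr index_t lanegroup_size  = 8;  // DPP8: 8 lanes share data
    static constexpr index_t m_per_wave      = 32; // rows of C per wave
    static constexpr index_t n_per_wave      = 8;  // columns of C per wave
    static constexpr index_t m_per_lanegroup = 8;  // rows of C per lanegroup
    static constexpr index_t n_per_lanegroup = 8;  // columns of C per lanegroup
    static constexpr index_t m_per_thread    = 8;  // rows of the per-thread C tile
};
```

With these numbers the wave's 32 threads split into four lanegroups of 8, each covering an 8 x 8 tile of the wave's 32 x 8 C block, so the fields are mutually consistent.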
```diff
@@ -254,16 +254,15 @@ struct DppGemm
         return make_tuple(m_dpp_idx, n_dpp_idx);
     }
 
-    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
     {
         const auto laneId   = get_thread_local_1d_id();
         const auto wave_row = laneId / dpp_instr.n_per_wave;
         auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
         return make_tuple(0, m_idx % dpp_instr.m_per_wave);
-        return make_tuple(0, laneId % dpp_instr.m_per_lanegroup);
     }
 
-    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
+    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
     {
         const auto laneId = get_thread_local_1d_id();
         return make_tuple(0, laneId % dpp_instr.n_per_wave);
```
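The lane-to-origin mapping above (note the hunk also drops an unreachable duplicate `return` in the A-side helper) can be traced with a few sample lanes. The DPP parameters below (n_per_wave = 8, m_per_thread = 8, lanegroup of 8, m_per_wave = 32) are assumed, matching the hypothetical DPP8 configuration sketched earlier:

```cpp
// Trace of the lane -> (k, m) and lane -> (k, n) origin mappings; parameters assumed.
#include <cstdio>

int main()
{
    constexpr int n_per_wave = 8, m_per_thread = 8, lanegroup_size = 8, m_per_wave = 32;

    for(int laneId : {0, 7, 8, 31})
    {
        const int wave_row      = laneId / n_per_wave;
        const int lane_in_group = laneId % lanegroup_size; // GetLaneIdInLaneGroup()
        const int m_idx         = m_per_thread * wave_row + lane_in_group;

        // A origin is (k, m); B origin is (k, n); k is always 0 here.
        std::printf("lane %2d -> A (0, %2d), B (0, %d)\n",
                    laneId, m_idx % m_per_wave, laneId % n_per_wave);
    }
    return 0;
}
```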
```diff
@@ -271,13 +270,13 @@ struct DppGemm
     __device__ static CIndex GetBeginOfThreadBlk()
     {
-        const auto dpp_idx   = GetDppOpIdx();
-        const auto m_dpp_idx = dpp_idx[I0];
-        const auto n_dpp_idx = dpp_idx[I1];
+        const auto dpp_op_idx   = GetDppOpIdx();
+        const auto m_dpp_op_idx = dpp_op_idx[I0];
+        const auto n_dpp_op_idx = dpp_op_idx[I1];
 
-        index_t n_offset = n_dpp_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
-        index_t m_offset = m_dpp_idx * dpp_instr.m_per_lanegroup;
+        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
+        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
 
         return CIndex{m_offset, n_offset};
     }
```
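The same parameters make the block-origin arithmetic above concrete; the DPP-op index and lane id below are assumed sample values:

```cpp
// Sketch of the per-thread C-block origin; lanegroup sizes and indices assumed.
#include <cstdio>

int main()
{
    constexpr int m_per_lanegroup = 8, n_per_lanegroup = 8;

    const int m_dpp_op_idx  = 3; // from GetDppOpIdx()[I0]
    const int n_dpp_op_idx  = 0; // from GetDppOpIdx()[I1]
    const int lane_in_group = 5; // from GetLaneIdInLaneGroup()

    const int m_offset = m_dpp_op_idx * m_per_lanegroup;                 // 24
    const int n_offset = n_dpp_op_idx * n_per_lanegroup + lane_in_group; // 5
    std::printf("C blk origin = {%d, %d}\n", m_offset, n_offset);
    return 0;
}
```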