Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8c84c0b1
Commit
8c84c0b1
authored
May 18, 2021
by
Jing Zhang
Browse files
add KReduction
parent
02bf2be0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
46 additions
and
6 deletions
+46
-6
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+24
-2
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+22
-4
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
8c84c0b1
...
...
@@ -55,7 +55,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
return
make_tuple
(
0
,
waveId_m
*
MPerWave
+
laneId
);
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
m_offset
);
}
else
{
const
index_t
m_offset
=
waveId_m
*
MPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
m_offset
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
...
...
@@ -66,7 +77,18 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const
index_t
waveId_m
=
waveId
/
NWaves
;
const
index_t
waveId_n
=
waveId
%
NWaves
;
return
make_tuple
(
0
,
waveId_n
*
NPerWave
+
laneId
);
if
constexpr
(
xdlops_gemm
.
IsKReduction
)
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
xdlops_gemm
.
GetBlkTd
(
laneId
);
const
index_t
k_offset
=
xdlops_gemm
.
GetBlkId
(
laneId
)
*
xdlops_gemm
.
mfma_type
.
k_base
;
return
make_tuple
(
k_offset
,
n_offset
);
}
else
{
const
index_t
n_offset
=
waveId_n
*
NPerWave
+
laneId
;
const
index_t
k_offset
=
0
;
return
make_tuple
(
k_offset
,
n_offset
);
}
}
template
<
index_t
AStride
=
MPerWave
,
index_t
BStride
=
NPerWave
>
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
8c84c0b1
...
...
@@ -535,6 +535,11 @@ struct xdlops_info
return
(
mfma_type
.
num_output_blks
==
1
)
&&
(
mfma_type
.
num_input_blks
>
1
);
}
// Returns mfma_type.k_base, multiplied by mfma_type.num_input_blks when
// IsKReduction() is true; otherwise k_base is returned unscaled.
static constexpr index_t GetKPerXdlops()
{
    if(IsKReduction())
    {
        return mfma_type.k_base * mfma_type.num_input_blks;
    }

    return mfma_type.k_base;
}
static
constexpr
auto
OutputVecType
=
OutputVecType_
{};
};
...
...
@@ -571,7 +576,7 @@ struct XdlopsGemm
static_assert
(
mfma_type
.
num_regs_blk
*
mfma_type
.
wave_size
==
mfma_type
.
m
*
mfma_type
.
n
,
"num_regs_blk incorrect"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k
and
k
_
base
is inconsistent
!"
);
static_assert
(
mfma_type
.
k
%
mfma_type
.
k_base
==
0
,
"k
%
kbase
!= 0
!"
);
}
__device__
static
constexpr
index_t
GetRegSizePerXdlops
()
...
...
@@ -586,7 +591,9 @@ struct XdlopsGemm
is_same
<
data_type
,
ushort
>::
value
,
"base data_type must be float, half, ushort!"
);
static_for
<
0
,
KPerWave
,
mfma_type
.
k_base
>
{}([
&
](
auto
k_i
)
{
static_assert
(
KPerWave
%
KPerXdlops
==
0
,
"KPerWave cannot be divided by KPerXdlops"
);
static_for
<
0
,
KPerWave
,
KPerXdlops
>
{}([
&
](
auto
k_i
)
{
mfma_type
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
Number
<
k_i
>
{}],
p_b_wave
[
Number
<
k_i
>
{}],
p_c_thread
);
});
...
...
@@ -833,8 +840,19 @@ struct XdlopsGemm
static
constexpr
index_t
MPerXdlops
=
GetXdlopsInfo
().
MPerXdlops
;
static
constexpr
index_t
NPerXdlops
=
GetXdlopsInfo
().
NPerXdlops
;
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
static
constexpr
bool
IsKReduction
=
GetXdlopsInfo
().
IsKReduction
();
static
constexpr
bool
IsABroadcast
=
GetXdlopsInfo
().
IsABroadcast
();
static
constexpr
index_t
KPerXdlops
=
GetXdlopsInfo
().
GetKPerXdlops
();
// Quotient of lane_id by mfma_type.num_threads_blk.
// NOTE(review): presumably the index of the block this lane falls in — confirm
// against the mfma_type definition.
static constexpr auto GetBlkId(const index_t lane_id)
{
    const auto threads_per_blk = mfma_type.num_threads_blk;
    return lane_id / threads_per_blk;
}
// Remainder of lane_id modulo mfma_type.num_threads_blk.
// NOTE(review): presumably the lane's position inside its block — confirm
// against the mfma_type definition.
static constexpr auto GetBlkTd(const index_t lane_id)
{
    const auto threads_per_blk = mfma_type.num_threads_blk;
    return lane_id % threads_per_blk;
}
static
constexpr
auto
mfma_type
=
GetXdlopsInfo
().
mfma_type
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment