gaoqiong / composable_kernel

Commit 9bdad55b, authored May 16, 2021 by Jing Zhang
parent 7084b152

    debugging

Showing 7 changed files with 253 additions and 426 deletions
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+10 -11)
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  (+66 -78)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp  (+85 -59)
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  (+12 -198)
composable_kernel/include/utility/amd_xdlops.hpp  (+65 -64)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+10 -11)
driver/src/conv_driver.cpp  (+5 -5)
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -111,12 +111,11 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     const auto GemmM0 = GemmM / Number<GemmM1>{};
     const auto GemmN0 = GemmN / Number<GemmN1>{};

-    const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor(
+    const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor(
         out_gemmm_gemmn_global_desc,
-        make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)),
-                   make_unmerge_transform(make_tuple(GemmN0, GemmN1))),
+        make_tuple(make_unmerge_transform(make_tuple(4, 2, 4)),
+                   make_pass_through_transform(N)),
         make_tuple(Sequence<0>{}, Sequence<1>{}),
-        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
+        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

     // out_gemm_block_cluster_desc
     const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2(

@@ -141,23 +140,23 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
     // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global
     // tensor hack for NKHW format
-    constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks =
+    constexpr auto out_m0_m1_m2_n_global_iterator_hacks =
         make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{},
-                              Sequence<0, 0, 1, 0, 0>{}),
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}),
                    make_tuple(Sequence<0, 0, 0, 0, 0>{},
                               Sequence<0, 0, 0, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{},
-                              Sequence<0, 0, 2, 0, 0>{}));
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{}));

     return make_tuple(wei_gemmk_gemmm_global_desc,
                       in_gemmk_gemmn_global_desc,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc,
+                      out_m0_m1_m2_n_global_desc,
                       out_gemm_block_cluster_desc,
                       wei_gemmk_gemmm_global_iterator_hacks,
                       in_gemmk_gemmn_global_iterator_hacks,
-                      out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks,
+                      out_m0_m1_m2_n_global_iterator_hacks,
                       wei_gemmk_gemmm_global_move_slice_window_iterator_hacks,
                       in_gemmk_gemmn_global_move_slice_window_iterator_hacks);
 }
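Note on the descriptor change above: the old code split GemmM and GemmN into (GemmM0, GemmM1) and (GemmN0, GemmN1); the new code hard-codes a (4, 2, 4) unmerge of the M dimension (matching the K0 == 4, K1 == 2, K2 == 4 layout asserted later in gridwise_dynamic_gemm_xdlops.hpp) and passes N through unchanged. An unmerge transform is a plain row-major index split; the standalone sketch below (illustrative code, not part of the commit) shows the arithmetic:

    // Illustrative sketch (not from this commit) of the index arithmetic behind
    // make_unmerge_transform(make_tuple(4, 2, 4)): a row-major split of one dimension.
    #include <cstdio>

    int main()
    {
        const int L1 = 2, L2 = 4; // trailing factors of the (4, 2, 4) split in the diff above

        for(int i0 = 0; i0 < 4; ++i0)
            for(int i1 = 0; i1 < L1; ++i1)
                for(int i2 = 0; i2 < L2; ++i2)
                {
                    // the unmerged triple (i0, i1, i2) addresses original element idx
                    const int idx = (i0 * L1 + i1) * L2 + i2;
                    std::printf("(%d,%d,%d) -> %d\n", i0, i1, i2, idx);
                }
        return 0;
    }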
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp

@@ -11,35 +11,24 @@ namespace ck {
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,
           class ABlockDesc,
           class BBlockDesc,
-          index_t GemmMPerWave,
-          index_t GemmNPerWave,
-          index_t GemmKPerWave,
-          index_t GemmMWaves,
-          index_t GemmNWaves,
-          index_t GemmDataPerReadA, // \todo unused parameter, remove
-          index_t GemmDataPerReadB> // \todo unused parameter, remove
+          index_t MPerWave,
+          index_t NPerWave,
+          index_t KPerWave,
+          index_t MWaves,
+          index_t NWaves>
 struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
 {
+    using CIndex = MultiIndex<2>;
+
     struct MatrixIndex
     {
         index_t row;
         index_t col;
     };

     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

     static constexpr auto XdlopsGemm =
-        XdlopsGemm_t<float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
+        XdlopsGemm_t<float, MPerWave, NPerWave, KPerWave>{};

-    index_t mMyWaveOffsetA;
-    index_t mMyWaveOffsetB;
-
     static constexpr index_t WaveSize = 64;

@@ -55,7 +44,45 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         return XdlopsGemm.GetOutputLayout().GetBlkSize();
     }

+    __device__ static auto CalculateAThreadOriginDataIndex()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t waveId    = thread_id / WaveSize;
+        const index_t laneId    = thread_id % WaveSize;
+        const index_t waveId_m  = waveId / NWaves;
+        const index_t waveId_n  = waveId % NWaves;
+
+        return make_tuple(0, waveId_m * MPerWave + laneId);
+    }
+
+    __device__ static auto CalculateBThreadOriginDataIndex()
+    {
+        const index_t thread_id = get_thread_local_1d_id();
+        const index_t waveId    = thread_id / WaveSize;
+        const index_t laneId    = thread_id % WaveSize;
+        const index_t waveId_m  = waveId / NWaves;
+        const index_t waveId_n  = waveId % NWaves;
+
+        return make_tuple(0, waveId_n * NPerWave + laneId);
+    }
+
+    template <index_t AStride = MPerWave, index_t BStride = NPerWave>
+    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t blk_i)
+    {
+        const index_t waveId = get_thread_local_1d_id() / WaveSize;
+
+        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(blk_i);
+
+        const index_t row = (waveId / NWaves) * AStride + thread_mtx_on_blk.row;
+        const index_t col = (waveId % NWaves) * BStride + thread_mtx_on_blk.col;
+
+        return CIndex{row, col};
+    }
+
     __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1()
+        : a_thread_copy_{CalculateAThreadOriginDataIndex()},
+          b_thread_copy_{CalculateBThreadOriginDataIndex()}
     {
         static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

@@ -66,18 +93,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
         constexpr index_t N = BBlockDesc{}.GetLength(I1);

-        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
-        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");
-        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
-                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");
-
-        const index_t waveId   = get_thread_local_1d_id() / WaveSize;
-        const index_t waveId_m = waveId / GemmNWaves;
-        const index_t waveId_n = waveId % GemmNWaves;
-
-        mMyWaveOffsetA = waveId_m * GemmMPerWave;
-        mMyWaveOffsetB = waveId_n * GemmNPerWave;
+        static_assert(MPerWave * MWaves == M, "GemmMWaves * MPerWave != M");
+        static_assert(NPerWave * NWaves == N, "GemmNWaves * NPerWave != N");
+        static_assert(BlockSize == MWaves * NWaves * WaveSize,
+                      "BlockSize != MWaves * NWaves * WaveSize\n");
     }

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>

@@ -90,73 +110,41 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
         auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
             b_thread_desc_.GetElementSpaceSize());

 #if 0
-        constexpr auto threadwise_gemm =
-            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
-                                                FloatB,
-                                                FloatC,
-                                                decltype(a_thread_desc_),
-                                                decltype(b_thread_desc_),
-                                                CThreadDesc,
-                                                Sequence<GemmKPerWave>,
-                                                Sequence<M0_, M1PerThread>,
-                                                Sequence<N0_, N1PerThread>>{};
-
-        constexpr index_t K = ABlockDesc{}.GetLength(I0);
-
-        static_for<0, K, GemmKPerWave>{}([&](auto k) {
+        constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);
+
+        static_for<0, KPerBlock, KPerWave>{}([&](auto k) {
             a_thread_copy_.Run(ABlockDesc{},
-                               make_tuple(k, I0, I0),
+                               make_tuple(k, I0),
                                a_block_buf,
                                a_thread_desc_,
-                               make_tuple(I0, I0, I0),
+                               make_tuple(I0, I0),
                                a_thread_buf);

             b_thread_copy_.Run(BBlockDesc{},
-                               make_tuple(k, I0, I0),
+                               make_tuple(k, I0),
                                b_block_buf,
                                b_thread_desc_,
-                               make_tuple(I0, I0, I0),
+                               make_tuple(I0, I0),
                                b_thread_buf);

-            threadwise_gemm.Run(a_thread_buf,
-                                make_tuple(I0, I0, I0),
-                                b_thread_buf,
-                                make_tuple(I0, I0, I0),
-                                c_thread_buf,
-                                make_tuple(I0, I0, I0, I0));
+            XdlopsGemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf);
         });
 #endif
     }

-    template <index_t AStride = GemmMPerWave, index_t BStride = GemmNPerWave>
-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
-    {
-        const index_t waveId = get_thread_local_1d_id() / WaveSize;
-
-        const auto thread_mtx_on_blk = XdlopsGemm.GetBeginOfThreadBlk(i);
-
-        const index_t col = (waveId % GemmNWaves) * BStride + thread_mtx_on_blk.col;
-        const index_t row = (waveId / GemmNWaves) * AStride + thread_mtx_on_blk.row;
-
-        return MatrixIndex{row, col};
-    }
-
     private:
     // A[K, M]
     static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));
+        make_tuple(Number<KPerWave>{}, Number<1>{}));

     // B[K, N]
     static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<GemmKPerWave>{}, Number<1>{}));
+        make_tuple(Number<KPerWave>{}, Number<1>{}));

     using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                                 FloatA,
                                                                 ABlockDesc,
                                                                 decltype(a_thread_desc_),
-                                                                Sequence<GemmKPerWave, 1>,
+                                                                Sequence<KPerWave, 1>,
                                                                 Sequence<0, 1>,
                                                                 1,
                                                                 1,

@@ -166,14 +154,14 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
                                                                 FloatB,
                                                                 BBlockDesc,
                                                                 decltype(b_thread_desc_),
-                                                                Sequence<GemmKPerWave, 1>,
+                                                                Sequence<KPerWave, 1>,
                                                                 Sequence<0, 1>,
                                                                 1,
                                                                 1,
                                                                 1>;

-    // AThreadCopy a_thread_copy_;
-    // BThreadCopy b_thread_copy_;
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
 };

 } // namespace ck
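Note: the new CalculateAThreadOriginDataIndex/CalculateBThreadOriginDataIndex helpers replace the mMyWaveOffsetA/B members the old constructor computed; the arithmetic itself is unchanged, only moved into the thread-copy origins. The standalone sketch below reproduces that arithmetic with illustrative values (a 2x2 grid of 64-lane waves; the numbers are not taken from this commit):

    // Illustrative sketch (not from this commit) of the wave/lane decomposition used by
    // CalculateAThreadOriginDataIndex and CalculateBThreadOriginDataIndex above.
    #include <cassert>

    constexpr int WaveSize = 64;
    constexpr int MWaves = 2, NWaves = 2;      // illustrative wave grid
    constexpr int MPerWave = 64, NPerWave = 64;

    struct Origin { int a_col; int b_col; };

    Origin thread_origin(int thread_id)
    {
        const int waveId   = thread_id / WaveSize;
        const int laneId   = thread_id % WaveSize;
        const int waveId_m = waveId / NWaves;  // row of the wave grid
        const int waveId_n = waveId % NWaves;  // column of the wave grid

        // column offsets into A[K, M] and B[K, N], as in the diff above
        return {waveId_m * MPerWave + laneId, waveId_n * NPerWave + laneId};
    }

    int main()
    {
        // thread 200 sits in wave 3 = (waveId_m 1, waveId_n 1), lane 8
        const Origin o = thread_origin(200);
        assert(o.a_col == 1 * 64 + 8 && o.b_col == 1 * 64 + 8);
        return 0;
    }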
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp

@@ -32,7 +32,7 @@ __global__ void
             FloatC* __restrict__ p_c_global,
             const AGlobalDesc a_k_m_global_desc,
             const BGlobalDesc b_k_n_global_desc,
-            const CGlobalDesc c_m0_m1_n0_n1_global_desc,
+            const CGlobalDesc c_m0_m1_m2_n_global_desc,
             const CBlockClusterDesc c_block_cluster_desc)
 {
     GridwiseGemm::Run(p_a_global,

@@ -40,7 +40,7 @@ __global__ void
                       p_c_global,
                       a_k_m_global_desc,
                       b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});

@@ -68,7 +68,7 @@ __global__ void
             FloatC* __restrict__ p_c_global,
             const void __CONSTANT__* p_a_k_m_global_desc,
             const void __CONSTANT__* p_b_k_n_global_desc,
-            const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc,
+            const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
             const void __CONSTANT__* p_c_block_cluster_desc)
 {
     // first cast void __CONSTANT__ void* to void*

@@ -78,8 +78,8 @@ __global__ void
         *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k_m_global_desc);
     const auto b_k_n_global_desc =
         *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k_n_global_desc);
-    const auto c_m0_m1_n0_n1_global_desc =
-        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_n0_n1_global_desc);
+    const auto c_m0_m1_m2_n_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
     const auto c_block_cluster_desc =
         *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);

@@ -89,7 +89,7 @@ __global__ void
                       p_c_global,
                       a_k_m_global_desc,
                       b_k_n_global_desc,
-                      c_m0_m1_n0_n1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
                       c_block_cluster_desc,
                       integral_constant<bool, HasMainKBlockLoop>{},
                       integral_constant<bool, HasDoubleTailKBlockLoop>{});

@@ -174,7 +174,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
                                const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                FloatAB* __restrict__ p_shared_block,
                                integral_constant<bool, HasMainKBlockLoop>,

@@ -190,7 +190,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
             p_b_global, b_k_n_global_desc.GetElementSpaceSize());
         auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
-            p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize());
+            p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());

         const auto K = a_k_m_global_desc.GetLength(I0);
         const auto M = a_k_m_global_desc.GetLength(I1);

@@ -309,23 +309,20 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1, 2>{}));

-        constexpr auto c_m0_m1_n0_n1_thread_desc =
-            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
-                Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+        // constexpr auto c_m0_m1_n0_n1_thread_desc =
+        //     make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+        //         Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

         const auto blockwise_gemm = BlockwiseGemmXdlops_km_kn_m0m1m2n_v1<BlockSize,
                                                                          FloatAB,
                                                                          FloatAB,
                                                                          FloatAcc,
                                                                          decltype(a_k_m_block_desc),
                                                                          decltype(b_k_n_block_desc),
                                                                          64, // MPerWave,
                                                                          64, // NPerWave,
-                                                                         1,  // KPerWave,
-                                                                         2,  // MWaves,
-                                                                         2,  // NWaves,
-                                                                         1,  // GemmDataPerReadM,
-                                                                         1   // GemmDataPerReadN
+                                                                         KPerBlock,
+                                                                         1,  // MWaves,
+                                                                         1   // NWaves,
                                                                          >{};

         // LDS allocation for A and B: be careful of alignment

@@ -339,13 +336,15 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
-            c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
-
-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m0_m1_n0_n1_thread_desc),
-                                           Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
-            .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
+        // auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
+        //     c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+
+        vector_type<float, 64> c_thread_buf;
+
+        // ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+        //                                    decltype(c_m0_m1_n0_n1_thread_desc),
+        //                                    Sequence<MRepeat, MPerThread, NRepeat, NPerThread>>{}
+        //     .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

@@ -474,43 +473,70 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
             blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

+#if 0
         // output: register to global memory
         {
-            constexpr auto M1 = Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{};
-            constexpr auto N1 = Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{};
-
-            // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
-            constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};
-
-            const auto c_thread_data_idx_on_block =
-                blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());
+            StaticBuffer<AddressSpace::Vgpr, float, 64> c_thread_buf_;
+
+            static_for<0, 64, 1>{}(
+                [&](auto i) { c_thread_buf_(i) = c_thread_buf.template AsType<float>()[i]; });
+
+            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
+            constexpr index_t K0 = OutputLayout.M1();
+            constexpr index_t K1 = OutputLayout.N1();
+            constexpr index_t K2 = OutputLayout.M0();
+
+            static_assert(K0 == 4 && K1 == 2 && K2 == 4, "");
+
+            constexpr auto c_m0_m1_m2_n_thread_desc =
+                make_dynamic_naive_tensor_descriptor_packed_v2(
+                    make_tuple(Number<K0>{}, Number<1>{}, Number<K2>{}, Number<1>{}));
+
+            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
+            constexpr index_t NumBlks = OutputLayout.GetNumBlks();
+
+            static_assert(BlkSize == 16 && NumBlks == 4, "");
+
+            // force unrolling the output loop to get ride of scratches
+            for(index_t i = 0; i < NumBlks; ++i)
+            {
+                // blockwise GEMM c matrix starting index
+                const auto c_thread_mtx_on_block =
+                    blockwise_gemm.CalculateCThreadOriginDataIndex(i);
+
+                const index_t k_thread_data_on_global =
+                    m_block_data_idx_on_global + c_thread_mtx_on_block[I0];
+
+                const index_t b_thread_data_on_global =
+                    n_block_data_idx_on_global + c_thread_mtx_on_block[I1];
+
+                constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

                 ThreadwiseDynamicTensorSliceTransfer_v1r3<
                     FloatAcc,
                     FloatC,
-                    decltype(c_m0_m1_n0_n1_thread_desc),
-                    decltype(c_m0_m1_n0_n1_global_desc),
-                    Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
-                    CThreadTransferSrcDstAccessOrder,
-                    CThreadTransferSrcDstVectorDim,
-                    CThreadTransferDstScalarPerVector,
+                    decltype(c_m0_m1_m2_n_thread_desc),
+                    decltype(c_m0_m1_m2_n_global_desc),
+                    Sequence<K0, 1, K2, 1>,
+                    Sequence<0, 1, 2, 3>, // CThreadTransferSrcDstAccessOrder,
+                    3,                    // CThreadTransferSrcDstVectorDim,
+                    1,                    // CThreadTransferDstScalarPerVector,
                     CGlobalMemoryDataOperation,
                     1,
                     true>{
-                    c_m0_m1_n0_n1_global_desc,
-                    make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
-                                     c_thread_data_idx_on_block[I1],
-                                     n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2],
-                                     c_thread_data_idx_on_block[I3])}
-                    .Run(c_m0_m1_n0_n1_thread_desc,
-                         make_tuple(I0, I0, I0, I0),
-                         c_thread_buf,
-                         c_m0_m1_n0_n1_global_desc,
-                         c_global_buf,
-                         c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+                    c_m0_m1_m2_n_global_desc,
+                    make_multi_index(k_thread_data_on_global / (K2 * K1),
+                                     k_thread_data_on_global % (K2 * K1) / K2,
+                                     k_thread_data_on_global % K2,
+                                     b_thread_data_on_global)}
+                    .Run(c_m0_m1_m2_n_thread_desc,
+                         make_tuple(I0, I0, I0, I0),
+                         c_thread_buf_,
+                         c_m0_m1_m2_n_global_desc,
+                         c_global_buf,
+                         c_m0_m1_n0_n1_global_tensor_iterator_hacks);
+            }
         }
+#endif
     }

     template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>

@@ -519,7 +545,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                                FloatC* __restrict__ p_c_global,
                                const AGlobalDesc& a_k_m_global_desc,
                                const BGlobalDesc& b_k_n_global_desc,
-                               const CGlobalDesc& c_m0_m1_n0_n1_global_desc,
+                               const CGlobalDesc& c_m0_m1_m2_n_global_desc,
                                const CBlockClusterDesc& c_block_cluster_desc,
                                integral_constant<bool, HasMainKBlockLoop>,
                                integral_constant<bool, HasDoubleTailKBlockLoop>)

@@ -533,7 +559,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
                 p_c_global,
                 a_k_m_global_desc,
                 b_k_n_global_desc,
-                c_m0_m1_n0_n1_global_desc,
+                c_m0_m1_m2_n_global_desc,
                 c_block_cluster_desc,
                 p_shared_block,
                 integral_constant<bool, HasMainKBlockLoop>{},
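Note: the new write-back path decomposes each thread's flat C-row index against the xdlops output layout (K0, K1, K2) = (4, 2, 4) asserted above. The standalone sketch below (illustrative code, not part of the commit) verifies that the three quotient/remainder expressions form an ordinary mixed-radix split:

    // Illustrative check of the (m0, m1, m2) decomposition used when writing C back to
    // global memory: a flat row index k is split against K1 = 2 and K2 = 4.
    // (K0 = 4 bounds m0 but does not enter the arithmetic itself.)
    #include <cassert>

    int main()
    {
        constexpr int K1 = 2, K2 = 4; // from static_assert(K0 == 4 && K1 == 2 && K2 == 4, "")

        const int k  = 27;                 // an arbitrary flat row index
        const int m0 = k / (K2 * K1);      // 27 / 8 = 3
        const int m1 = k % (K2 * K1) / K2; // 27 % 8 / 4 = 0
        const int m2 = k % K2;             // 27 % 4 = 3

        // recombining the digits reproduces k, confirming a plain mixed-radix split
        assert((m0 * K1 + m1) * K2 + m2 == k);
        return 0;
    }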
composable_kernel/include/tensor_operation/xdlops_gemm.hpp

@@ -50,20 +50,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
     static constexpr index_t cycles = 64;
     static constexpr index_t k_base = 1;

-    template <index_t MPerXdlops,
-              index_t NPerXdlops,
-              index_t AStride,
-              index_t BStride,
-              class FloatA,
-              class FloatB,
-              class FloatC>
-    __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
-    {
-        const auto p_a = reinterpret_cast<const float*>(a);
-        const auto p_b = reinterpret_cast<const float*>(b);
-
-        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops, AStride, BStride>::run(
-            p_a, p_b, reg_c);
-    }
+    template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+    {
+        return intrin_mfma_f32_32x32x1f32<MPerXdlops, NPerXdlops>::run(a, b, reg_c);
+    }
 };

@@ -557,11 +547,7 @@ struct xdlops_info
     static constexpr auto OutputVecType = OutputVecType_{};
 };

-template <class data_type,
-          index_t GemmMPerWave,
-          index_t GemmNPerWave,
-          index_t GemmDataPerReadA,
-          index_t GemmDataPerReadB>
+template <class data_type, index_t MPerWave, index_t NPerWave, index_t KPerWave>
 struct XdlopsGemm_t
 {
     struct MatrixIndex

@@ -585,8 +571,6 @@ struct XdlopsGemm_t
                       MPerXdlops == 64,
                   "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");

-    static_assert(GemmDataPerReadA == 1 && GemmDataPerReadB == 1, "GemmDataPerReadA/B != 1");
-
     static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk");
     static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m,
                   "m != num_input_blks * num_regs_blk");

@@ -604,187 +588,17 @@ struct XdlopsGemm_t
         return MPerXdlops * NPerXdlops / mfma_type.wave_size;
     }

-#if CK_USE_AMD_XDLOPS_EMULATE
-    // emulate xdlops
-    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
-    __device__ FloatC XdlopsEmulate(const FloatA* const __restrict__ p_a_wave,
-                                    const FloatB* const __restrict__ p_b_wave,
-                                    FloatC p_c_thread) const
-    {
-        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-        const index_t blk_id = laneId / mfma_type.num_threads_blk;
-        const index_t blk_td = laneId % mfma_type.num_threads_blk;
-
-        // K reduction
-        static_if<IsKReduction>{}([&](auto) {
-            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
-            {
-                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                {
-                    index_t a_off = (k + n) * M;
-                    index_t b_off = (k + n) * N;
-                    index_t c_off = 0;
-
-                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                    {
-                        index_t aindex = m % mfma_type.group_size + blk_id * mfma_type.group_size +
-                                         m / mfma_type.group_size *
-                                             (mfma_type.group_size * mfma_type.num_input_blks);
-                        index_t bindex = blk_td;
-
-                        p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
-                            p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
-                    }
-                }
-            }
-        }).Else([&](auto) {
-            static_if<IsABroadcast>{}([&](auto) {
-                // ABroadcast
-                for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-                {
-                    for(index_t n_i = 0; n_i < NRepeats; ++n_i)
-                    {
-                        for(index_t k = 0; k < K; ++k)
-                        {
-                            for(index_t b = 0; b < MPerXdlops / mfma_type.m; ++b)
-                            {
-                                for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                                {
-                                    index_t a_off = k * M + b * mfma_type.m + MPerXdlops * m_i;
-                                    index_t b_off = k * N + n * mfma_type.num_threads_blk +
-                                                    NPerXdlops * n_i;
-                                    index_t c_off = n * mfma_type.num_regs_blk +
-                                                    b * mfma_type.num_regs_xdlops +
-                                                    (NRepeats * m_i + n_i) * GetRegSizePerXdlops();
-
-                                    for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                                    {
-                                        index_t aindex =
-                                            m % mfma_type.group_size +
-                                            blk_id * mfma_type.group_size +
-                                            m / mfma_type.group_size *
-                                                (mfma_type.group_size * mfma_type.num_input_blks);
-                                        index_t bindex = blk_td;
-
-                                        p_c_thread.n[m + c_off] +=
-                                            inner_product_with_conversion<float>{}(
-                                                p_a_wave[aindex + a_off],
-                                                p_b_wave[bindex + b_off]);
-                                    }
-                                }
-                            }
-                        }
-                    }
-                }
-            }).Else([&](auto) {
-                // BBroadcast
-                for(index_t k = 0; k < K; ++k)
-                {
-                    for(index_t b = 0; b < NPerXdlops / mfma_type.n; ++b)
-                    {
-                        for(index_t n = 0; n < mfma_type.num_input_blks; ++n)
-                        {
-                            index_t a_off = k * M + n * mfma_type.m;
-                            index_t b_off = k * N + b * mfma_type.n;
-                            index_t c_off =
-                                n * mfma_type.num_regs_blk + b * mfma_type.num_regs_xdlops;
-
-                            for(index_t m = 0; m < mfma_type.num_regs_blk; ++m)
-                            {
-                                index_t aindex = m % mfma_type.group_size +
-                                                 blk_id * mfma_type.group_size +
-                                                 m / mfma_type.group_size *
-                                                     (mfma_type.group_size *
-                                                      mfma_type.num_input_blks);
-                                index_t bindex = blk_td;
-
-                                p_c_thread.n[m + c_off] += inner_product_with_conversion<float>{}(
-                                    p_a_wave[aindex + a_off], p_b_wave[bindex + b_off]);
-                            }
-                        }
-                    }
-                }
-            });
-        });
-
-        return p_c_thread;
-    }
-#endif
-
-    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>
-    __device__ FloatC Run(const FloatA* const __restrict__ p_a_wave,
-                          const FloatB* const __restrict__ p_b_wave,
-                          FloatC p_c_thread) const
-    {
-        static_assert(is_same<FloatA, FloatB>::value, "FloatA != FloatB");
-
-        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
-                          is_same<data_type, ushort>::value,
-                      "base data_type must be float, half, ushort!");
-
-#if CK_USE_AMD_XDLOPS_EMULATE
-        p_c_thread = XdlopsEmulate<M, N, K>(p_a_wave, p_b_wave, p_c_thread);
-#else
-        constexpr index_t KPACT = sizeof(FloatA) / sizeof(data_type);
-
-        static_assert(KPACT % mfma_type.k_base == 0, "wrong! KPACT is not supported by mfma");
-
-        constexpr index_t KRepeats = KPACT / mfma_type.k_base;
-
-        static_assert(!IsKReduction || K % mfma_type.num_input_blks == 0,
-                      "K cannot divided by mfma_type.num_input_blks!");
-
-        constexpr index_t KPerThread = IsKReduction ? K / mfma_type.num_input_blks : K;
-
-        static_assert(!IsKReduction || (MRepeats == 1 && NRepeats == 1),
-                      "KReduction does not support M/N Repeats!");
-
-        FloatA a[KPerThread * MRepeats];
-        FloatB b[KPerThread * NRepeats];
-
-        auto pa = reinterpret_cast<const data_type*>(&a);
-        auto pb = reinterpret_cast<const data_type*>(&b);
-
-        constexpr index_t AStride = KPerThread * KRepeats;
-        constexpr index_t BStride = KPerThread * KRepeats;
-
-        const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size;
-
-        static_if<!IsKReduction>{}([&](auto) {
-            for(index_t m_i = 0; m_i < MRepeats; ++m_i)
-                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-                    a[k_i + m_i * KPerThread] = p_a_wave[k_i * M + laneId + MPerXdlops * m_i];
-
-            for(index_t n_i = 0; n_i < NRepeats; ++n_i)
-                for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-                    b[k_i + n_i * KPerThread] = p_b_wave[k_i * N + laneId + NPerXdlops * n_i];
-        }).Else([&](auto) {
-            const index_t blk_id = laneId / mfma_type.num_threads_blk;
-            const index_t blk_td = laneId % mfma_type.num_threads_blk;
-
-            for(index_t k_i = 0; k_i < KPerThread; ++k_i)
-            {
-                a[k_i] = p_a_wave[(k_i * mfma_type.num_input_blks + blk_id) * M + blk_td];
-                b[k_i] = p_b_wave[(k_i * mfma_type.num_input_blks + blk_id) * N + blk_td];
-            }
-        });
-
-#if CK_WORKAROUND_SWDEV_229564
-#pragma unroll
-#endif
-        for(index_t k_i = 0; k_i < KPerThread * KRepeats; ++k_i)
-        {
-            p_c_thread =
-                mfma_type.template run<MPerXdlops * MRepeats, NPerXdlops * NRepeats, AStride, BStride>(
-                    &pa[k_i * mfma_type.k_base], &pb[k_i * mfma_type.k_base], p_c_thread);
-        }
-#endif
-        return p_c_thread;
-    }
+    template <class FloatA, class FloatB, class FloatC>
+    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
+    {
+        static_assert(is_same<data_type, float>::value || is_same<data_type, half_t>::value ||
+                          is_same<data_type, ushort>::value,
+                      "base data_type must be float, half, ushort!");
+
+        static_for<0, KPerWave, mfma_type.k_base>{}([&](auto k_i) {
+            mfma_type.template run<MPerXdlops, NPerXdlops>(
+                p_a_wave[Number<k_i>{}], p_b_wave[Number<k_i>{}], p_c_thread);
+        });
+    }

     __device__ static MatrixIndex GetBeginOfThreadBlk(index_t i)

@@ -821,8 +635,8 @@ struct XdlopsGemm_t
     }

     template <class data_type_ = data_type,
-              index_t MPerWave_ = GemmMPerWave,
-              index_t NPerWave_ = GemmNPerWave>
+              index_t MPerWave_ = MPerWave,
+              index_t NPerWave_ = NPerWave>
     static constexpr auto GetXdlopsInfo();

     template <>
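Note: the removed Run() derived its mfma issue count from the packing of the input type: KPACT = sizeof(FloatA) / sizeof(data_type) elements per pack, consumed k_base elements per instruction, giving KRepeats = KPACT / k_base issues per packed element. The standalone sketch below re-states that bookkeeping with plain C++ stand-ins for CK's types (illustrative, not part of the commit):

    // Illustrative sketch of the KPACT/KRepeats arithmetic from the removed Run() above.
    // float4_pack is a hypothetical stand-in for a 4-wide vector type, not a CK name.
    #include <cassert>

    template <class Pack, class Base>
    constexpr int k_repeats(int k_base)
    {
        static_assert(sizeof(Pack) % sizeof(Base) == 0, "pack must hold whole elements");
        constexpr int kpact = sizeof(Pack) / sizeof(Base); // elements packed per FloatA
        return kpact / k_base;                             // mfma issues per packed element
    }

    struct float4_pack { float v[4]; };

    int main()
    {
        // mfma_f32_32x32x1f32 consumes one float per issue (k_base = 1),
        // so a 4-wide pack yields 4 repeats along K
        assert((k_repeats<float4_pack, float>(1) == 4));
        return 0;
    }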
composable_kernel/include/utility/amd_xdlops.hpp

@@ -198,78 +198,79 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
 extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
     ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");

-template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
+template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x1f32;

-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<128, 64, AStride, BStride>
+//{
+//    __device__ static c_vec32_4_t::VecType
+//    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
+//    {
+//        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+//        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
+//        reg_c.s.z = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
+//        reg_c.s.w = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
+//        return reg_c;
+//    }
+//};

-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
-{
-    __device__ static c_vec32_4_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        reg_c.s.z = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
-        reg_c.s.w = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
-        return reg_c;
-    }
-};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<64, 128, AStride, BStride>
+//{
+//    __device__ static c_vec32_4_t::VecType
+//    run(const float* reg_a, const float* reg_b, c_vec32_4_t::VecType reg_c)
+//    {
+//        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+//        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
+//        reg_c.s.z = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
+//        reg_c.s.w = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
+//        return reg_c;
+//    }
+//};

-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 64, AStride, BStride>
-{
-    __device__ static c_vec32_2_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_2_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
-        return reg_c;
-    }
-};
-
-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
-        return reg_c;
-    }
-};
+template <>
+struct intrin_mfma_f32_32x32x1f32<64, 64>
+{
+    __device__ static void
+    run(const float& reg_a, const float& reg_b, vector_type<float, 64>& reg_c)
+    {
+        reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
+        reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
+            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
+    }
+};

-template <index_t AStride, index_t BStride>
-struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
-{
-    __device__ static c_vec32_1_t::VecType
-    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
-    {
-        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
-        return reg_c;
-    }
-};
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<64, 32, AStride, BStride>
+//{
+//    __device__ static c_vec32_1_t::VecType
+//    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
+//    {
+//        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
+//        return reg_c;
+//    }
+//};
+
+// template <index_t AStride, index_t BStride>
+// struct intrin_mfma_f32_32x32x1f32<32, 64, AStride, BStride>
+//{
+//    __device__ static c_vec32_1_t::VecType
+//    run(const float* reg_a, const float* reg_b, c_vec32_1_t::VecType reg_c)
+//    {
+//        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
+//        return reg_c;
+//    }
+//};

 __device__ c_vec16_1_t::VecType
 intrin_mfma_f32_32x32x2f32(const float* reg_a, const float* reg_b, c_vec16_1_t::VecType reg_c)
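Note: the wrappers above rely on a primary template that is declared but never defined, plus one full specialization per supported (MPerWave, NPerWave) tile, so an unsupported tile size fails at compile time instead of silently selecting a wrong intrinsic. A minimal standalone sketch of that dispatch pattern (illustrative names, not CK's):

    // Illustrative sketch (not from this commit) of dispatch via an undefined primary
    // template plus per-tile full specializations.
    #include <cassert>

    template <int MPerWave, int NPerWave>
    struct intrin_dispatch; // primary template intentionally left undefined

    template <>
    struct intrin_dispatch<64, 64>
    {
        // stand-in for the mfma call; accumulates one product into acc
        static void run(float a, float b, float& acc) { acc += a * b; }
    };

    int main()
    {
        float acc = 1.0f;
        intrin_dispatch<64, 64>::run(2.0f, 3.0f, acc); // compiles: specialization exists
        // intrin_dispatch<64, 32>::run(...) would fail to compile: no definition
        assert(acc == 7.0f);
        return 0;
    }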
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -77,12 +77,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
 #endif

-    // cdata = 64, BlockSize = 256, 128x128x8
     // b thread copy 4x1
-    constexpr index_t BlockSize = 256;
+    constexpr index_t BlockSize = 64;

-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmMPerBlock = 64;
+    constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
     constexpr index_t GemmMPerThread = 4;

@@ -91,17 +90,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 2;

-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
+    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;

-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4;
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
     constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;

-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN   = Sequence<4, 2>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;

     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1;
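Note: the new 64x64x8 configuration stays self-consistent: per-thread slice lengths times the thread-cluster lengths must tile the (GemmK, GemmM) block, and the cluster must account for every thread in the block (2 * 32 = 64 = BlockSize). A standalone check of that arithmetic (illustrative code, not part of the commit):

    // Illustrative consistency check for the new A-block transfer parameters:
    // slice <4, 2> per thread, cluster <2, 32> threads, block 8 x 64, BlockSize 64.
    #include <cassert>

    int main()
    {
        constexpr int BlockSize = 64, GemmMPerBlock = 64, GemmKPerBlock = 8;
        constexpr int SliceK = 4, SliceM = 2, ClusterK = 2, ClusterM = 32;

        static_assert(SliceK * ClusterK == GemmKPerBlock, "K not tiled exactly");
        static_assert(SliceM * ClusterM == GemmMPerBlock, "M not tiled exactly");
        static_assert(ClusterK * ClusterM == BlockSize, "cluster must use every thread");
        return 0;
    }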
driver/src/conv_driver.cpp

@@ -25,11 +25,11 @@ int main(int argc, char* argv[])
     using namespace ck;

 #if 1
-    constexpr index_t N  = 8;
+    constexpr index_t N  = 4;
     constexpr index_t C  = 16;
     constexpr index_t HI = 4;
     constexpr index_t WI = 4;
-    constexpr index_t K  = 128;
+    constexpr index_t K  = 64;
     constexpr index_t Y  = 1;
     constexpr index_t X  = 1;

@@ -688,7 +688,7 @@ int main(int argc, char* argv[])
 #elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-#elif 0
+#elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
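Note: flipping the generators so the input is random (GeneratorTensor_2{-5, 5}) and the weights are all ones (GeneratorTensor_1) is a common debugging setup: with Y = X = 1, every output value reduces to a channel sum of the input at that pixel, which is easy to verify by hand. A minimal host-side sketch of that reference (illustrative code, not part of the commit):

    // Illustrative reference for the 1x1, all-ones-weights case: every output channel
    // equals the channel sum of the input, independent of k.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main()
    {
        constexpr int C = 3, H = 2, W = 2, K = 2; // tiny illustrative shape
        std::vector<float> in(C * H * W);
        for(std::size_t i = 0; i < in.size(); ++i)
            in[i] = float(int(i % 11) - 5); // mimics GeneratorTensor_2{-5, 5}

        for(int h = 0; h < H; ++h)
            for(int w = 0; w < W; ++w)
            {
                float ref = 0.f; // channel sum = expected output
                for(int c = 0; c < C; ++c) ref += in[(c * H + h) * W + w];

                for(int k = 0; k < K; ++k)
                {
                    float out = 0.f; // 1x1 conv with weights all 1
                    for(int c = 0; c < C; ++c) out += 1.f * in[(c * H + h) * W + w];
                    assert(out == ref); // k-independent, easy to eyeball
                }
            }
        return 0;
    }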