Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0bb08f4b
Commit
0bb08f4b
authored
May 10, 2023
by
aska-0096
Browse files
1. Enable 2-stage global Prefetch ( May cause VGPR spilling)
2. Enable FP16 accumulator blockwise_gemm
parent
716860e3
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
101 additions
and
73 deletions
+101
-73
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+8
-8
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+4
-0
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
+12
-12
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
..._bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+7
-7
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+21
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+3
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+6
-3
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+20
-21
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+3
-3
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
0bb08f4b
...
@@ -35,24 +35,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -35,24 +35,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp
,
BElementOp
,
CElementOp
,
CElementOp
,
GemmDefault
,
GemmDefault
,
1
,
// Prefetch stage
2
,
// Prefetch stage
256
,
// BlockSize
128
,
// BlockSize
128
,
// MPerBlock
64
,
// MPerBlock
256
,
// NPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
4
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
64
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
64
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
...
@@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -61,7 +61,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
32
,
1
,
8
>
,
S
<
1
,
32
,
1
,
4
>
,
8
>
;
8
>
;
// clang-format on
// clang-format on
...
...
example/01_gemm/run_gemm_example.inc
View file @
0bb08f4b
...
@@ -47,6 +47,10 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -47,6 +47,10 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
break
;
break
;
case
5
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
2.
f
,
2.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
2.
f
,
2.
f
}(
b_k_n
);
break
;
default
:
default
:
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
ADataType
>
{
-
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
...
...
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
View file @
0bb08f4b
...
@@ -80,34 +80,34 @@ using DeviceOpInstance =
...
@@ -80,34 +80,34 @@ using DeviceOpInstance =
BElementOp
,
BElementOp
,
CDEElementOp
,
CDEElementOp
,
GemmSpec
,
GemmSpec
,
1
,
2
,
64
,
128
,
32
,
64
,
64
,
128
,
64
,
64
,
8
,
8
,
16
,
16
,
16
,
16
,
2
,
2
,
2
,
4
,
S
<
4
,
16
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
8
,
4
,
8
,
true
,
true
,
S
<
4
,
16
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
4
,
8
,
4
,
8
,
true
,
true
,
1
,
1
,
1
,
1
,
S
<
1
,
2
,
1
,
32
>
,
S
<
1
,
32
,
1
,
4
>
,
1
>
;
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
View file @
0bb08f4b
...
@@ -67,24 +67,24 @@ using DeviceOpInstanceKKNN =
...
@@ -67,24 +67,24 @@ using DeviceOpInstanceKKNN =
ASpec
,
ASpec
,
BSpec
,
BSpec
,
DESpec
,
DESpec
,
1
,
2
,
256
,
128
,
128
,
64
,
128
,
128
,
32
,
32
,
8
,
8
,
16
,
16
,
16
,
16
,
8
,
2
,
1
,
4
,
S
<
4
,
64
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
64
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
...
@@ -93,7 +93,7 @@ using DeviceOpInstanceKKNN =
...
@@ -93,7 +93,7 @@ using DeviceOpInstanceKKNN =
true
,
true
,
1
,
1
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
1
,
32
,
1
,
4
>
,
8
>
;
8
>
;
using
DeviceOpInstance
=
DeviceOpInstanceKKNN
;
using
DeviceOpInstance
=
DeviceOpInstanceKKNN
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
0bb08f4b
...
@@ -211,10 +211,27 @@ struct BlockwiseGemmWMMA
...
@@ -211,10 +211,27 @@ struct BlockwiseGemmWMMA
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I0
];
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
)
);
#if 0
return make_naive_tensor_descriptor_packed(
return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
// |NThreadPerSubGroup |MAccVgprs
...
@@ -225,6 +242,7 @@ struct BlockwiseGemmWMMA
...
@@ -225,6 +242,7 @@ struct BlockwiseGemmWMMA
I1,
I1,
NThreadPerSubGroup,
NThreadPerSubGroup,
MAccVgprs));
MAccVgprs));
#endif
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -140,8 +140,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -140,8 +140,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -151,9 +151,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -151,9 +151,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
B0EnableLds
=
B0EnableLds_auto
||
B0EnableLds_manu
;
static
constexpr
auto
B0EnableLds
=
B0EnableLds_auto
||
B0EnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
B1EnableLds
=
B1EnableLds_auto
||
B1EnableLds_manu
;
static
constexpr
auto
B1EnableLds
=
B1EnableLds_auto
||
B1EnableLds_manu
||
(
NumPrefetch
>
1
)
;
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
using
Transform
=
TransformBatchedContractionContractionToBatchedGemmGemm
<
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
Sequence
<
NumDimG
,
NumDimM
,
NumDimL
,
NumDimK
,
NumDimN
>
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
0bb08f4b
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
)
;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
...
@@ -467,7 +467,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -467,7 +467,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
if
(
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
if
(
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
ck
::
get_device_name
()
==
"gfx1102"
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
ck
::
half_t
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
{
printf
(
"DeviceOp err: AccDataType"
);
printf
(
"DeviceOp err: AccDataType"
);
return
false
;
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -177,8 +177,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
...
@@ -177,8 +177,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumGemmKPrefetchStage
>
1
)
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
;
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumGemmKPrefetchStage
>
1
)
;
static
constexpr
auto
conv_to_gemm_transformer
=
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -868,7 +868,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -868,7 +868,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
>
(
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc
,
a_grid_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_element_op
,
...
@@ -943,7 +944,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -943,7 +944,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
1
,
1
,
1
,
1
,
B0ThreadTransferSrcResetCoordinateAfterRun
,
B0ThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
,
NumGemmKPrefetchStage
>
(
b0_grid_desc
,
b0_grid_desc
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
b0_element_op
,
b0_element_op
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
0bb08f4b
...
@@ -874,7 +874,8 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -874,7 +874,8 @@ struct GridwiseGemmMultipleD_Wmma
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
>
(
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc
,
a_grid_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_element_op
,
...
@@ -950,7 +951,8 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -950,7 +951,8 @@ struct GridwiseGemmMultipleD_Wmma
1
,
1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
,
NumGemmKPrefetchStage
>
(
b_grid_desc
,
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_element_op
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
0bb08f4b
...
@@ -636,7 +636,8 @@ struct GridwiseGemm_Wmma
...
@@ -636,7 +636,8 @@ struct GridwiseGemm_Wmma
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
>
(
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc
,
a_grid_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_element_op
,
...
@@ -712,7 +713,8 @@ struct GridwiseGemm_Wmma
...
@@ -712,7 +713,8 @@ struct GridwiseGemm_Wmma
1
,
1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
,
NumGemmKPrefetchStage
>
(
b_grid_desc
,
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_element_op
,
...
@@ -814,10 +816,11 @@ struct GridwiseGemm_Wmma
...
@@ -814,10 +816,11 @@ struct GridwiseGemm_Wmma
/*******************************************************************************/
/*******************************************************************************/
// write out to C, implement shuffle
// write out to C, implement shuffle
{
{
// C mapping in single thread.
constexpr
auto
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
constexpr
auto
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
blockwise_gemm
.
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
blockwise_gemm
.
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
// This API Provide All dimension (size) you need
// C mapping in single block
constexpr
auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
=
constexpr
auto
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
=
blockwise_gemm
.
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
blockwise_gemm
.
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
0bb08f4b
...
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
...
@@ -89,6 +89,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
...
@@ -100,12 +101,11 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
...
@@ -100,12 +101,11 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
// * num_acc_vgprs_per_wave alone M direction
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
*
acc_pack_number
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
index_t
NPerWmma
,
bool
AssemblyBackend
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
...
@@ -113,7 +113,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
...
@@ -113,7 +113,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
{
{
if
constexpr
(
wave_size
==
32
)
if
constexpr
(
wave_size
==
32
)
{
{
intrin_wmma_f32_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_f32_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
}
else
if
constexpr
(
wave_size
==
64
)
else
if
constexpr
(
wave_size
==
64
)
{
{
...
@@ -134,6 +134,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
...
@@ -134,6 +134,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
// Wave mode dependent propety
...
@@ -141,7 +142,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
...
@@ -141,7 +142,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
*
acc_pack_number
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
...
@@ -158,7 +159,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
...
@@ -158,7 +159,6 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
}
}
};
};
#ifdef CK_UNPACKED_ACC_DESC_LOGIC
template
<
index_t
WaveSize
>
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f16_16x16x16_f16
,
struct
wmma_type
<
WmmaInstr
::
wmma_f16_16x16x16_f16
,
WaveSize
,
WaveSize
,
...
@@ -171,6 +171,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
...
@@ -171,6 +171,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
2
;
static
constexpr
index_t
acc_pack_number
=
2
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
// Wave mode dependent propety
...
@@ -178,12 +179,11 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
...
@@ -178,12 +179,11 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
*
acc_pack_number
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
index_t
NPerWmma
,
index_t
Opsel
,
class
FloatA
,
class
FloatA
,
class
FloatB
,
class
FloatB
,
class
FloatC
>
class
FloatC
>
...
@@ -191,15 +191,14 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
...
@@ -191,15 +191,14 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
{
{
if
constexpr
(
wave_size
==
32
)
if
constexpr
(
wave_size
==
32
)
{
{
intrin_wmma_f16_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
,
Opsel
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_f16_16x16x16_f16_w32
<
MPerWmma
,
NPerWmma
,
false
>::
Run
(
a
,
b
,
reg_c
);
}
}
else
if
constexpr
(
wave_size
==
64
)
else
if
constexpr
(
wave_size
==
64
)
{
{
intrin_wmma_f16_16x16x16_f16_w64
<
MPerWmma
,
NPerWmma
,
Opsel
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_f16_16x16x16_f16_w64
<
MPerWmma
,
NPerWmma
,
false
>::
Run
(
a
,
b
,
reg_c
);
}
}
}
}
};
};
template
<
index_t
WaveSize
>
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_bf16_16x16x16_bf16
,
struct
wmma_type
<
WmmaInstr
::
wmma_bf16_16x16x16_bf16
,
WaveSize
,
WaveSize
,
...
@@ -212,6 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
...
@@ -212,6 +211,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
2
;
static
constexpr
index_t
acc_pack_number
=
2
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
// Wave mode dependent propety
...
@@ -219,7 +219,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
...
@@ -219,7 +219,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
*
acc_pack_number
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
template
<
index_t
MPerWmma
,
...
@@ -232,17 +232,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
...
@@ -232,17 +232,15 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
{
{
if
constexpr
(
wave_size
==
32
)
if
constexpr
(
wave_size
==
32
)
{
{
intrin_wmma_bf16_16x16x16_bf16_w32
<
MPerWmma
,
NPerWmma
,
Opsel
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_bf16_16x16x16_bf16_w32
<
MPerWmma
,
NPerWmma
,
false
>::
Run
(
a
,
b
,
reg_c
);
}
}
else
if
constexpr
(
wave_size
==
64
)
else
if
constexpr
(
wave_size
==
64
)
{
{
intrin_wmma_bf16_16x16x16_bf16_w64
<
MPerWmma
,
NPerWmma
,
Opsel
>::
Run
(
a
,
b
,
reg_c
);
intrin_wmma_bf16_16x16x16_bf16_w64
<
MPerWmma
,
NPerWmma
,
false
>::
Run
(
a
,
b
,
reg_c
);
}
}
}
}
};
};
#endif
template
<
index_t
WaveSize
>
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_i32_16x16x16_iu8
,
struct
wmma_type
<
WmmaInstr
::
wmma_i32_16x16x16_iu8
,
WaveSize
,
WaveSize
,
...
@@ -255,6 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
...
@@ -255,6 +253,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_data_size
=
4
;
static
constexpr
index_t
acc_pack_number
=
1
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
// Wave mode dependent propety
...
@@ -262,7 +261,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
...
@@ -262,7 +261,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
m_per_wmma
*
n_per_wmma
*
acc_data_size
*
acc_pack_number
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
template
<
index_t
MPerWmma
,
...
@@ -351,7 +350,7 @@ struct WmmaSelector
...
@@ -351,7 +350,7 @@ struct WmmaSelector
static_assert
(
selected_wmma
.
k_per_wmma
==
16
,
"WRONG! WMMA_M must equal to 16"
);
static_assert
(
selected_wmma
.
k_per_wmma
==
16
,
"WRONG! WMMA_M must equal to 16"
);
static_assert
(
selected_wmma
.
wave_size
*
selected_wmma
.
num_acc_vgprs_per_wave
*
static_assert
(
selected_wmma
.
wave_size
*
selected_wmma
.
num_acc_vgprs_per_wave
*
selected_wmma
.
acc_data_size
==
selected_wmma
.
acc_data_size
*
selected_wmma
.
acc_pack_number
==
selected_wmma
.
m_per_wmma
*
selected_wmma
.
n_per_wmma
*
4
,
selected_wmma
.
m_per_wmma
*
selected_wmma
.
n_per_wmma
*
4
,
"WRONG! Invalid Number of Accumulator Register"
);
"WRONG! Invalid Number of Accumulator Register"
);
}
}
...
@@ -469,7 +468,7 @@ struct WmmaGemm
...
@@ -469,7 +468,7 @@ struct WmmaGemm
__device__
static
constexpr
index_t
GetRegSizePerWmma
()
__device__
static
constexpr
index_t
GetRegSizePerWmma
()
{
{
return
wmma_instr
.
num_acc_vgprs_per_wave
;
return
wmma_instr
.
num_acc_vgprs_per_wave
*
wmma_instr
.
acc_pack_number
;
}
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
wmma_instr
.
wave_size
;
}
__device__
static
constexpr
index_t
GetWaveSize
()
{
return
wmma_instr
.
wave_size
;
}
...
@@ -497,12 +496,12 @@ struct WmmaGemm
...
@@ -497,12 +496,12 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!"
);
"(int8, int32) or (int4, int32)!"
);
if
constexpr
(
!
TransposeC
)
if
constexpr
(
!
TransposeC
)
{
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
}
else
else
{
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
p_b_wave
,
p_a_wave
,
p_c_thread
);
}
}
}
}
...
@@ -556,7 +555,7 @@ struct WmmaGemm
...
@@ -556,7 +555,7 @@ struct WmmaGemm
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
{
{
return
make_tuple
(
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{});
return
make_tuple
(
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{}
,
Number
<
wmma_instr
.
acc_pack_number
>
{}
);
}
}
};
};
...
...
include/ck/utility/amd_wmma.hpp
View file @
0bb08f4b
...
@@ -12,11 +12,11 @@ namespace ck {
...
@@ -12,11 +12,11 @@ namespace ck {
/********************************WAVE32 MODE***********************************************/
/********************************WAVE32 MODE***********************************************/
// src: fp16, dst: fp32
// src: fp16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
,
bool
AssemblyBackend
>
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32
;
struct
intrin_wmma_f32_16x16x16_f16_w32
;
template
<
bool
AssemblyBackend
>
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
,
AssemblyBackend
>
struct
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
>
{
{
template
<
class
FloatC
>
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
__device__
static
void
Run
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
FloatC
&
reg_c
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment