gaoqiong / composable_kernel

Commit 63f87662
authored Dec 15, 2022 by aska-0096

    tidy up

parent 13af8cc4
Showing 6 changed files with 274 additions and 1034 deletions (+274 −1034)
example/01_gemm/gemm_wmma_fp16.cpp                                                +6    −13
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                     +224  −720
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp                  +6    −3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                       +20   −238
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  +0    −3
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                +18   −57
example/01_gemm/gemm_wmma_fp16.cpp

@@ -22,20 +22,13 @@ using CElementOp = PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

 // clang-format off
-// using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmWmma
-// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWMMA|NMMMA| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
-// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
-// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
-// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 6, 1>;
-// clang-format on
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
-// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MWmma|NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
+// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
+// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
-// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
+// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::
...
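The only functional change in this example is the instance's K0PerBlock tuning parameter (4 → 8). Below is a minimal sketch of the tile arithmetic behind it, under CK's usual K = K0 * K1 split (column meaning taken from the table comment above); this is a standalone illustration, not CK code.

// Standalone sketch: raising K0PerBlock from 4 to 8 at K1 = 8 doubles the
// K-extent each workgroup consumes per main-loop step (assumes K = K0 * K1).
#include <cstdio>
#include <initializer_list>

int main()
{
    const int K1 = 8;
    for(int k0_per_block : {4, 8}) // old vs. new DeviceGemmInstance value
        std::printf("K0PerBlock = %d -> KPerBlock = %d\n", k0_per_block, k0_per_block * K1);
    return 0;
}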
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

(diff collapsed)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -201,7 +201,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma<
         BlockSize,
-        ADataType, // TODO: distinguish A/B datatype
+        ADataType,
+        BDataType,
         AccDataType,
         CShuffleDataType,
         CDataType,
...
@@ -353,7 +354,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             const auto kernel = kernel_gemm_wmma<
                 GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
+                ADataType,
+                BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                 remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
...
@@ -384,7 +386,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             const auto kernel = kernel_gemm_wmma<
                 GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
+                ADataType,
+                BDataType,
                 CDataType,
                 remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                 remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
...
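All three hunks resolve the same TODO: the gridwise GEMM and its launch kernel now take separate A and B data types instead of one shared type. A minimal, compilable sketch of that pattern (stub names are assumptions, not the CK headers):

// Sketch: one template parameter per operand type, so a single kernel
// template can describe mixed-precision GEMMs (e.g. fp32 A with int8 B).
#include <cstdint>
#include <cstdio>

template <typename FloatA, typename FloatB, typename FloatC>
void kernel_gemm_stub(const FloatA*, const FloatB*, FloatC*)
{
    std::printf("element bytes: A=%zu B=%zu C=%zu\n",
                sizeof(FloatA), sizeof(FloatB), sizeof(FloatC));
}

int main()
{
    kernel_gemm_stub<float, std::int8_t, float>(nullptr, nullptr, nullptr);
    return 0;
}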
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -18,7 +18,8 @@
 namespace ck {

 template <typename GridwiseGemm,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
           typename FloatC,
           typename AGridDesc_K0_M_K1,
           typename BGridDesc_K0_N_K1,
...
@@ -33,8 +34,8 @@ __global__ void
 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
     kernel_gemm_wmma(
-        const FloatAB* __restrict__ p_a_grid,
+        const FloatA* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
+        const FloatB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
...
@@ -77,7 +78,8 @@ __global__ void
 }

 template <index_t BlockSize,
-          typename FloatAB,
+          typename FloatA,
+          typename FloatB,
           typename FloatAcc,
           typename FloatCShuffle,
           typename FloatC,
...
@@ -216,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

-        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+        return (a_block_space_size_aligned * sizeof(FloatA) + b_block_space_size_aligned * sizeof(FloatB));
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
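This sizing fix matters once A and B may have different element types: the old expression multiplied the summed element counts by a single sizeof(FloatAB). A small compilable sketch of the corrected per-operand byte accounting (element counts below are placeholders for the aligned descriptor sizes in the hunk):

// Sketch: each aligned element count is scaled by its own element size
// before the byte totals are summed, matching the new return statement.
#include <cstddef>
#include <cstdio>

template <typename FloatA, typename FloatB>
constexpr std::size_t lds_bytes(std::size_t a_elems_aligned, std::size_t b_elems_aligned)
{
    return a_elems_aligned * sizeof(FloatA) + b_elems_aligned * sizeof(FloatB);
}

int main()
{
    // e.g. a 4096-element fp32 A tile plus a 4096-element int8 B tile
    std::printf("%zu bytes\n", lds_bytes<float, signed char>(4096, 4096));
    return 0;
}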
...
@@ -270,120 +272,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

-    // Vector write
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
-        const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        constexpr auto WmmaK = 16;
-        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            decltype(a_block_desc_k0perblock_mperblock_k1),
-            decltype(b_block_desc_k0perblock_nperblock_k1),
-            MPerWmma,
-            NPerWmma,
-            MRepeat,
-            NRepeat,
-            KPack>;
-
-        return BlockwiseGemm::
-            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
-                c_grid_desc_m_n);
-    }
-
-    // Per pixel
-    __host__ __device__ static constexpr auto
-    MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-        const CGridDesc_M_N& c_grid_desc_m_n)
-    {
-        constexpr auto max_lds_align = K1;
-
-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() {
-            if constexpr(ABlockLdsExtraM)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
-                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
-            }
-            else
-            {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
-            }
-        }();
-
-        constexpr auto WmmaK = 16;
-        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
-
-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
-            BlockSize,
-            FloatAB,
-            FloatAcc,
-            decltype(a_block_desc_k0perblock_mperblock_k1),
-            decltype(b_block_desc_k0perblock_nperblock_k1),
-            MPerWmma,
-            NPerWmma,
-            MRepeat,
-            NRepeat,
-            KPack>;
-
-        return BlockwiseGemm::
-            MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-                c_grid_desc_m_n);
-    }
-
     __host__ __device__ static constexpr auto
     MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
     {
...
@@ -410,11 +298,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
             c_grid_desc_m_n);
     }

-    // using CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
-    //     = remove_cvref_t<decltype(
-    //         MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-    //             CGridDesc_M_N{}))>;
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;

     using DefaultBlock2CTileMap =
...
@@ -422,17 +306,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
-    Run(const FloatAB* __restrict__ p_a_grid,
+    Run(const FloatA* __restrict__ p_a_grid,
-        const FloatAB* __restrict__ p_b_grid,
+        const FloatB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
         void* __restrict__ p_shared,
         const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
         const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        // const CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup&
-        //     c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
...
@@ -476,8 +357,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             /* typename BlockSliceLengths,         */ Sequence<K0PerBlock, MPerBlock, K1>,
             /* typename ThreadClusterLengths,      */ ABlockTransferThreadClusterLengths_K0_M_K1,
             /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-            /* typename SrcData,                   */ FloatAB,
+            /* typename SrcData,                   */ FloatA,
-            /* typename DstData,                   */ FloatAB,
+            /* typename DstData,                   */ FloatA,
             /* typename SrcDesc,                   */ decltype(a_grid_desc_k0_m_k1),
             /* typename DstDesc,                   */ decltype(a_block_desc_k0perblock_mperblock_k1),
             /* typename SrcDimAccessOrder,         */ ABlockTransferSrcAccessOrder,
...
@@ -496,8 +377,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 a_block_desc_k0perblock_mperblock_k1,
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});
-        // printf("BlockSliceLengths K0 = %d, M = %d, K1 = %d\n", K0PerBlock, MPerBlock, K1());
-        // printf("a_block_wise_copy: %s\n", std::string(type_name<decltype(a_blockwise_copy)>()).c_str());

         // B matrix blockwise copy
         auto b_blockwise_copy =
...
@@ -508,8 +387,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             Sequence<K0PerBlock, NPerBlock, K1>,
             BBlockTransferThreadClusterLengths_K0_N_K1,
             BBlockTransferThreadClusterArrangeOrder,
-            FloatAB,
+            FloatB,
-            FloatAB,
+            FloatB,
             decltype(b_grid_desc_k0_n_k1),
             decltype(b_block_desc_k0perblock_nperblock_k1),
             BBlockTransferSrcAccessOrder,
...
@@ -530,18 +409,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                 ck::tensor_operation::element_wise::PassThrough{});

         /*******************************************************************************/
-        // GEMM definition
-        //   c_mtx += a_mtx * b_mtx
-        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
-        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
-        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
+        // GEMM
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
             BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
-                                                                       FloatAB,
+                                                                       FloatA,
+                                                                       FloatB,
                                                                        FloatAcc,
                                                                        decltype(a_block_desc_k0perblock_mperblock_k1),
                                                                        decltype(b_block_desc_k0perblock_nperblock_k1),
...
@@ -557,8 +432,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         /*******************************************************************************/
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);

         // LDS allocation for A and B: be careful of alignment
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared),
-            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
+        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatA*>(p_shared),
+            a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
-            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());
+        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
+            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

         // Shift Per SUB_K
         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
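For context, here is a sketch of carving the two typed tile regions out of the one raw p_shared allocation (assumed layout, not the CK make_dynamic_buffer helpers). Byte-based offsetting is shown because the hunk above advances a FloatB* by the A element count, which coincides with the intended byte offset only while sizeof(FloatA) == sizeof(FloatB).

// Sketch: partition one raw shared buffer into a FloatA region followed by
// a FloatB region, computing the split point in bytes.
#include <cstddef>
#include <cstdio>

template <typename FloatA, typename FloatB>
void partition_shared(void* p_shared, std::size_t a_elems_aligned, FloatA** a, FloatB** b)
{
    *a = static_cast<FloatA*>(p_shared);
    *b = reinterpret_cast<FloatB*>(static_cast<char*>(p_shared) +
                                   a_elems_aligned * sizeof(FloatA));
}

int main()
{
    alignas(16) static char shared[8192];
    float* a = nullptr;
    short* b = nullptr;
    partition_shared(shared, 1024, &a, &b);
    std::printf("B region starts at byte offset %td\n", reinterpret_cast<char*>(b) - shared);
    return 0;
}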
...
@@ -582,101 +457,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             c_thread_buf,
             K0BlockMainLoop);

         /*******************************************************************************/
-#ifdef CK_EXPERIMENTAL_ARBITRARY_WRITEOUT
-        // write out to C, implement shuffle
-        // write out C matrix, c shuffle not implemented
-        {
-            static_for<0, 16, 1>{}([&](auto i) {
-                char info[4];
-                info[0] = 'C';
-                info[1] = i / 10 + '0';
-                info[2] = i % 10 + '0';
-                info[3] = '\0';
-                debug_hexprinter(0xffffffff, c_thread_buf[Number<i>{}], info);
-            });
-
-            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
-            // This API Provide All dimension (size) you need
-            constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
-                blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
-
-            constexpr auto MWave =
-                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I1);
-            constexpr auto MSubGroup =
-                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I2);
-            constexpr auto NWave =
-                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I4);
-            constexpr auto NThreadPerSubGroup =
-                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I5);
-            constexpr auto MAccVgprs =
-                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs.GetLength(I6);
-            // printf("MWave = %d, MSubGroup = %d, NWave = %d, NThreadPerSubGroup = %d, MAccVgprs = %d\n", MWave, MSubGroup, NWave, NThreadPerSubGroup, MAccVgprs);
-
-            // Mapping
-            const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);
-            const index_t m_thread_data_on_grid = m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
-            const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
-            // Checked
-            // debug_hexprinter(0xffffffff, m_thread_data_on_grid, "c_m");
-            // debug_hexprinter(0xffffffff, n_thread_data_on_grid, "c_n");
-
-            const auto m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
-                    make_tuple(Sequence<0, 1, 2, 3>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
-                make_single_stage_tensor_adaptor(
-                    make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
-                    make_tuple(Sequence<0, 1, 2>{}),
-                    make_tuple(Sequence<0>{}));
-
-            const auto m_thread_data_on_grid_idx =
-                m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex(
-                    make_multi_index(m_thread_data_on_grid));
-
-            debug_hexprinter(0x4, MRepeat, "mblockxrepeat");
-            debug_hexprinter(0x2, MWave, "mwave");
-            debug_hexprinter(0x2, MSubGroup, "msubgroup");
-            debug_hexprinter(0x8, MAccVgprs, "maccvgprs");
-            debug_hexprinter(0x4, NWave, "nwave");
-
-            const auto n_thread_data_on_grid_idx =
-                n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex(
-                    make_multi_index(n_thread_data_on_grid));
-            // printf("write out dimension access order = (%d, %d, %d, %d, %d, %d, %d)\n", CThreadTransferSrcDstAccessOrder{}[Number<0>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<1>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<2>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<3>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<4>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<5>{}].value, CThreadTransferSrcDstAccessOrder{}[Number<6>{}].value);
-
-            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
-                /* typename SrcData                     */ FloatAcc,
-                /* typename DstData                     */ FloatC,
-                /* typename SrcDesc                     */ decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
-                /* typename DstDesc                     */ decltype(c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup),
-                /* typename ElementwiseOperation        */ CElementwiseOperation,
-                // Thread register Mapping 0 1 2 4 5 6 3
-                /* typename SliceLengths                */ Sequence<MRepeat, I1, I1, NRepeat, I1, I1, MAccVgprs>,
-                /* typename DimAccessOrder              */ CThreadTransferSrcDstAccessOrder,
-                /* index_t DstVectorDim                 */ CThreadTransferSrcDstVectorDim,
-                /* index_t DstScalarPerVector           */ CThreadTransferDstScalarPerVector,
-                /* InMemoryDataOperationEnum DstInMemOp */ CGlobalMemoryDataOperation,
-                /* index_t DstScalarStrideInVector      */ 1,
-                /* bool DstResetCoordinateAfterRun      */ true>{
-                /* dst_desc             */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
-                /* dst_slice_origin_idx */ make_multi_index(m_thread_data_on_grid_idx[I0],
-                                                            m_thread_data_on_grid_idx[I1],
-                                                            m_thread_data_on_grid_idx[I2],
-                                                            m_thread_data_on_grid_idx[I3],
-                                                            n_thread_data_on_grid_idx[I0],
-                                                            n_thread_data_on_grid_idx[I1],
-                                                            n_thread_data_on_grid_idx[I2]),
-                /* element_op           */ c_element_op};
-
-            c_thread_copy.Run(
-                /* c_thread_desc        */ c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
-                /* c_register_beginning */ make_tuple(I0, I0, I0, I0, I0, I0, I0),
-                /* c_local(register)    */ c_thread_buf,
-                /* c_grid_desc          */ c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
-                /* c_grid_buf           */ c_grid_buf);
-        }
-#endif

         {
+            // write out to C, implement shuffle
             constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                 blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -128,12 +128,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
-        // printf("src_access_lengths: %d, %d, %d\n", (src_access_lengths[Number<0>{}])(), src_access_lengths[Number<1>{}](), src_access_lengths[Number<2>{}]());
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};

         constexpr auto ordered_src_access_lengths =
             container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
-        // printf("ordered_src_access_lengths: %d, %d, %d\n", (ordered_src_access_lengths[Number<0>{}])(), ordered_src_access_lengths[Number<1>{}](), ordered_src_access_lengths[Number<2>{}]());

         // make forward steps
         const auto src_forward_steps = generate_tuple(
...
@@ -210,7 +208,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};

             // apply SrcElementwiseOperation on src_vector_container
-            // debug_hexprinter(0xffffffff, src_coord_.GetOffset());
             static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                 SrcData src_v;
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -283,51 +283,51 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     }
 };

-template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
+template <typename src_type_a, typename src_type_b, typename dst_type, index_t MPerWmma, index_t NPerWmma>
 struct WmmaSelector
 {
-    template <typename src_type_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
+    template <typename src_type_a_, typename src_type_b_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
     static constexpr auto GetWmma();

     template <>
-    static constexpr auto GetWmma<half_t, float, 16, 16>()
+    static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
     {
         return WmmaInstr::wmma_f32_16x16x16_f16;
     }

     template <>
-    static constexpr auto GetWmma<bhalf_t, float, 16, 16>()
+    static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
     {
         return WmmaInstr::wmma_f32_16x16x16_bf16;
     }

     template <>
-    static constexpr auto GetWmma<half_t, half_t, 16, 16>()
+    static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
     {
         return WmmaInstr::wmma_f16_16x16x16_f16;
     }

     template <>
-    static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>()
+    static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
     {
         return WmmaInstr::wmma_bf16_16x16x16_bf16;
     }

     template <>
-    static constexpr auto GetWmma<int8_t, float, 16, 16>()
+    static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu8;
     }

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
-    static constexpr auto GetWmma<int4_t, float, 16, 16>()
+    static constexpr auto GetWmma<int4_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu4;
     }
 #endif

     // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
     static constexpr auto selected_wmma =
-        wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
+        wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};

     __host__ __device__ constexpr WmmaSelector()
     {
...
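The selector now keys the WMMA instruction on the (src_a, src_b, dst) type triple rather than a single source type. A portable stand-in sketch of that compile-time dispatch using if constexpr (simplified types and enum; the real header uses explicit GetWmma specializations over half_t/bhalf_t/int8_t as above):

// Sketch: map an operand-type triple to an instruction tag at compile time.
#include <cstdint>
#include <cstdio>
#include <type_traits>

enum class WmmaInstr { wmma_f32_16x16x16_f16, wmma_i32_16x16x16_iu8, unsupported };

struct f16_stand_in {}; // placeholder for half_t, which needs GPU headers

template <typename SrcA, typename SrcB, typename Dst>
constexpr WmmaInstr select_wmma()
{
    if constexpr(std::is_same<SrcA, f16_stand_in>::value &&
                 std::is_same<SrcB, f16_stand_in>::value &&
                 std::is_same<Dst, float>::value)
        return WmmaInstr::wmma_f32_16x16x16_f16;
    else if constexpr(std::is_same<SrcA, std::int8_t>::value &&
                      std::is_same<SrcB, std::int8_t>::value &&
                      std::is_same<Dst, std::int32_t>::value)
        return WmmaInstr::wmma_i32_16x16x16_iu8;
    else
        return WmmaInstr::unsupported;
}

int main()
{
    static_assert(select_wmma<std::int8_t, std::int8_t, std::int32_t>() ==
                      WmmaInstr::wmma_i32_16x16x16_iu8,
                  "int8 x int8 -> int32 maps to the iu8 instruction");
    std::printf("selection ok\n");
    return 0;
}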
@@ -344,7 +344,8 @@ struct WmmaSelector
     }
 };

-template <typename src_type,
+template <typename src_type_a,
+          typename src_type_b,
           typename dst_type,
           index_t MPerWmma,
           index_t NPerWmma,
...
@@ -412,46 +413,6 @@ struct WmmaGemm
                        Sequence<5>{}));
     }

-    // Per-Pixel write
-    template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
-    __host__ __device__ static constexpr auto
-    MakeCDesc_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup(
-        const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
-    {
-        const auto MBlockxRepeat =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
-        const auto NBlockxRepeat =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
-        const auto MWave =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
-        const auto NWave =
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
-
-        return transform_tensor_descriptor(
-            c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
-            make_tuple(make_pass_through_transform(MBlockxRepeat),
-                       make_pass_through_transform(MWave),
-                       make_unmerge_transform(
-                           make_tuple(Number<wmma_instr.num_subgroups>{},
-                                      Number<wmma_instr.num_acc_vgprs_per_wave>{})),
-                       make_pass_through_transform(NBlockxRepeat),
-                       make_pass_through_transform(NWave),
-                       make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}));
-    }
-
     __device__ static constexpr index_t GetRegSizePerWmma()
     {
         return wmma_instr.num_acc_vgprs_per_wave;
...
@@ -463,13 +424,13 @@ struct WmmaGemm
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
         static_assert(
-            (is_same<src_type, half_t>::value && is_same<dst_type, float>::value) ||
+            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
+             is_same<dst_type, float>::value) ||
-            (is_same<src_type, bhalf_t>::value && is_same<dst_type, float>::value) ||
+            (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
+             is_same<dst_type, float>::value) ||
-            (is_same<src_type, half_t>::value && is_same<dst_type, half_t>::value) ||
+            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
+             is_same<dst_type, half_t>::value) ||
-            (is_same<src_type, bhalf_t>::value && is_same<dst_type, bhalf_t>::value) ||
+            (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
+             is_same<dst_type, bhalf_t>::value) ||
-            (is_same<src_type, int8_t>::value && is_same<dst_type, int32_t>::value)
+            (is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
+             is_same<dst_type, int32_t>::value)
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
-            || (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
+            || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
+                is_same<dst_type, int32_t>::value)
 #endif
             ,
             "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
...
@@ -518,7 +479,7 @@ struct WmmaGemm
         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }

-    static constexpr auto wmma = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
+    static constexpr auto wmma = WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;

     __host__ __device__ static constexpr auto
...