gaoqiong / composable_kernel · Commits

Commit d4adc71a, authored Feb 24, 2023 by aska-0096
Mat-A LDS Bypass sanity pass
Parent: c811a0e9

Showing 12 changed files with 887 additions and 342 deletions (+887 -342)
Files changed:

  example/01_gemm/gemm_wmma_fp16.cpp                                            +40   -6
  example/01_gemm/run_gemm_example.inc                                          +12   -0
  include/ck/host_utility/kernel_launch.hpp                                     +1    -1
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                 +148  -51
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp              +100  -92
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp      +3    -1
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp            +122  -5
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                   +396  -160
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp   +22   -17
  include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                            +11   -4
  include/ck/utility/amd_wmma.hpp                                               +8    -5
  include/ck/utility/data_type.hpp                                              +24   -0
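The diffs below repeatedly key off one rule: an operand only has to be staged through LDS when it is shared between waves, so AEnableLds/BEnableLds are derived from the wave counts. A minimal, standalone sketch of that selection (a sketch only; the tile numbers are taken from the updated gemm_wmma_fp16.cpp configuration, everything else is simplified):

// Sketch only: mirrors the AEnableLds/BEnableLds rule added in this commit.
#include <cstdio>

constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int MRepeat   = 1,   NRepeat   = 8;
constexpr int MPerWmma  = 16,  NPerWmma  = 16;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 8 waves tile M
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 1 wave tiles N

// When only one wave covers the other dimension, the operand is not shared
// across waves, so it can stay in VGPRs and skip the LDS staging buffer.
constexpr bool AEnableLds = (NWaves != 1); // false here: Mat-A bypasses LDS
constexpr bool BEnableLds = (MWaves != 1); // true here:  Mat-B still uses LDS

int main()
{
    std::printf("MWaves=%d NWaves=%d AEnableLds=%d BEnableLds=%d\n",
                MWaves, NWaves, AEnableLds, BEnableLds);
    return 0;
}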
example/01_gemm/gemm_wmma_fp16.cpp

@@ -19,15 +19,49 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
 // ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
 // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
 // ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
 // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    <ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
-     AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4,
-     S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-     S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-     1, 1, S<1, 32, 1, 8>, 8, 1>;
+    <ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
+     AElementOp, BElementOp, CElementOp, GemmDefault,
+     256,  // BlockSize
+     128,  // MPerBlock
+     128,  // NPerBlock
+     64,   // KPerBlock
+     8,    // K1
+     16,   // MPerWmma
+     16,   // NPerWmma
+     1,    // M Repeat
+     8,    // N-Repeat
+     S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+     S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+     1,    // C shuffle (M Repeat) Per store
+     4,    // C shuffle (N Repeat) Per store
+     S<1, 64, 1, 4>, 8>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::...
...
example/01_gemm/run_gemm_example.inc

@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 2:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
        break;
    case 3:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
        break;
    case 4:
        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
        break;
    default:
        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
...
include/ck/host_utility/kernel_launch.hpp

@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     // warm up
     kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

-    const int nrepeat = 10;
+    const int nrepeat = 100;
 #if DEBUG_LOG
     printf("Start running %d times...\n", nrepeat);
 #endif
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

+    static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
+    static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
+    static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
+    static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
+    static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
+    // FIX it, workaround
+    using ABlockDesc_temp = decltype(make_naive_tensor_descriptor(
+        make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
+        make_tuple(A_temp1 * A_temp2 * A_temp3 * A_temp4,
+                   A_temp2 * A_temp3 * A_temp4,
+                   A_temp3 * A_temp4,
+                   A_temp4,
+                   I1)));

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

+    static constexpr bool AEnableLds = NWaves == 1 ? false : true;
+    static constexpr bool BEnableLds = MWaves == 1 ? false : true;
+    // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
+    static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
+    static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;

     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, FloatAcc, MRepeat * NRepeat,
...
@@ -91,26 +112,38 @@ struct BlockwiseGemmWMMA
     // Default, Block buffer in LDS, thread level offset enabled
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
         if constexpr(AEnableLds)
         {
             const auto wave_idx   = GetWaveIdx();
             const auto waveId_m   = wave_idx[I0];
             const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
             //  |KRepeat |MRepeat|MWave   |MLane     |KPack
             return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
         }
         else
         {
             return make_tuple(0, 0, 0, 0, 0);
         }
     }

     __device__ static auto CalculateBThreadOriginDataIndex()
     {
         if constexpr(BEnableLds)
         {
             const auto wave_idx   = GetWaveIdx();
             const auto waveId_n   = wave_idx[I1];
             const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
             //  |KRepeat |NRepeat|Nwave   |NLane     |KPack
             return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
         }
         else
         {
             return make_tuple(0, 0, 0, 0, 0);
         }
     }

     template <index_t m0, index_t n0>
     __device__ static auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
...
@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA
     // Describe how data allocated in thread copy src buffer
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
+    static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
     static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...
@@ -285,8 +318,10 @@ struct BlockwiseGemmWMMA
         static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 // read A
-                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                   make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
+                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                                   make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
+                                              m0, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(I0, m0, I0, I0, I0),
...
@@ -294,8 +329,13 @@ struct BlockwiseGemmWMMA
                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     // read B
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
+                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                                       make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
+                                                  n0, I0, I0, I0),
                                        b_block_buf,
                                        b_thread_desc_,
                                        make_tuple(I0, n0, I0, I0, I0),
...
@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
     }
...
@@ -340,7 +381,13 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
+    template <bool EnableLds>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                       FloatA,
                                                       decltype(a_block_desc_k0_m0_m1_m2_k1),
                                                       decltype(a_thread_desc_),
...
@@ -349,8 +396,33 @@ struct BlockwiseGemmWMMA
                                                       4,
                                                       A_K1,
                                                       A_K1>;
+    };
+
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+            FloatA,
+            FloatA,
+            decltype(a_block_desc_k0_m0_m1_m2_k1),
+            decltype(a_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<1, 1, 1, 1, A_K1>,
+            Sequence<0, 1, 2, 3, 4>,
+            4,
+            A_K1,
+            0x76543210,
+            0xfedcba98,
+            true>;
+    };
+
+    template <bool EnableLds>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
                                                       FloatB,
                                                       decltype(b_block_desc_k0_n0_n1_n2_k1),
                                                       decltype(b_thread_desc_),
...
@@ -359,9 +431,28 @@ struct BlockwiseGemmWMMA
                                                       4,
                                                       B_K1,
                                                       B_K1>;
+    };
+
-    AThreadCopy a_thread_copy_;
-    BThreadCopy b_thread_copy_;
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+            FloatB,
+            FloatB,
+            decltype(b_block_desc_k0_n0_n1_n2_k1),
+            decltype(b_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<1, 1, 1, 1, B_K1>,
+            Sequence<0, 1, 2, 3, 4>,
+            4,
+            B_K1,
+            0x76543210,
+            0xfedcba98,
+            false>;
+    };
+
+    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
+    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
...
@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...
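The AThreadCopySelector/BThreadCopySelector pattern above is plain tag dispatch on a compile-time bool. A stripped-down sketch of the same idea, with placeholder types standing in for the two real transfer classes (everything here is simplified, not the library's API):

#include <type_traits>

struct CopyThroughLds   {}; // stand-in for ThreadwiseTensorSliceTransfer_v4
struct CopyFromVgprOnly {}; // stand-in for ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow

// Primary template plus one explicit specialization per EnableLds value,
// the same shape as the selectors introduced in blockwise_gemm_wmma.hpp.
template <bool EnableLds>
struct ThreadCopySelector;

template <>
struct ThreadCopySelector<true>
{
    using type = CopyThroughLds;
};

template <>
struct ThreadCopySelector<false>
{
    using type = CopyFromVgprOnly;
};

// The blockwise GEMM then declares its member copies through the selector.
constexpr bool AEnableLds = false; // assumed value for illustration
using AThreadCopy = typename ThreadCopySelector<AEnableLds>::type;
static_assert(std::is_same_v<AThreadCopy, CopyFromVgprOnly>, "LDS-bypass copy selected");

int main() { return 0; }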
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -15,6 +15,7 @@
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

 namespace ck {
 namespace tensor_operation {
...
@@ -35,10 +36,10 @@ template <typename ALayout,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
+          ck::index_t KPerBlock,
           ck::index_t K1,
-          ck::index_t MPerWMMA,
-          ck::index_t NPerWMMA,
+          ck::index_t MPerWmma,
+          ck::index_t NPerWmma,
           ck::index_t MRepeat,
           ck::index_t NRepeat,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
...
@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;
+    static constexpr auto AEnableLds = NWaves == 1 ? false : true;
+    static constexpr auto BEnableLds = MWaves == 1 ? false : true;
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

+    // Describe how data read from Global memory
+    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+                const auto a_grid_desc_mraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
+                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
             }
 #ifdef ENABLE_COLMAJOR
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
...
@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
 #endif
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+        assert(K % K1 == 0);
+
+        if constexpr(AEnableLds)
         {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            const index_t K0 = K / K1;

             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
+                           make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else
         {
+            constexpr auto A_KRow = WmmaK / K1;
+            const auto A_KWmma    = K / WmmaK;
+            const auto M0         = M / MPerBlock;
             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(M)),
+                make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+                           make_unmerge_transform(
+                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
     {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
-        const auto b_grid_desc_k_n = [&]() {
+        const auto b_grid_desc_nraw_kraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+        assert(K % K1 == 0);
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
+    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
     {
-        const auto c_grid_desc_m_n = [&]() {
+        const auto c_grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(StrideC, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(I1, StrideC));
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
+        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }

     // Gridwise descriptor, mapping to whole given provblem.
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+    using AGridDesc         = decltype(MakeAGridDescriptor(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
     using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
...
@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CShuffleDataType,
         CDataType,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
+        AGridDesc,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
         AElementwiseOperation,
...
@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CElementwiseOperation,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
-        MPerWMMA,
-        NPerWMMA,
+        KPerBlock,
+        MPerWmma,
+        NPerWmma,
         K1,
         MRepeat,
         NRepeat,
...
@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
+        AEnableLds,
         ABlockLdsAddExtraM,
         BBlockTransferThreadClusterLengths_K0_N_K1,
         BBlockTransferThreadClusterArrangeOrder,
...
@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         BBlockTransferSrcScalarPerVector,
         BBlockTransferDstScalarPerVector_K1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BEnableLds,
         BBlockLdsAddExtraN,
         CShuffleMRepeatPerShuffle,
         CShuffleNRepeatPerShuffle,
...
@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
               p_c_grid_{p_c_grid},
-              a_grid_desc_k0_m_k1_{},
+              a_grid_desc_{},
               b_grid_desc_k0_n_k1_{},
               c_grid_desc_m_n_{},
               c_grid_desc_mblock_mperblock_nblock_nperblock{},
...
@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
         {
-            a_grid_desc_k0_m_k1_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+            a_grid_desc_         = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
             c_grid_desc_m_n_     = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
...
@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

-            if(GridwiseGemm::CheckValidity(
-                   a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
+            if(GridwiseGemm::CheckValidity(
+                   a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...
@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        AGridDesc a_grid_desc_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
 #if 0
             {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+                std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;

                 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                           << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
...
@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             }
 #endif

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                             arg.b_grid_desc_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,
                                             arg.block_2_ctile_map_))
...
@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+            const auto GetK = [&]() {
+                if constexpr(AEnableLds)
+                {
+                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
+                }
+                else
+                {
+                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
+                           arg.a_grid_desc_.GetLength(I5);
+                }
+            };
+            const auto K = GetK();

             float ave_time = 0;
...
@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     ADataType,
                     BDataType,
                     CDataType,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                     remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                     remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...
@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_k0_m_k1_,
+                    arg.a_grid_desc_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
...
@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     ADataType,
                     BDataType,
                     CDataType,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                     remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                     remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...
@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_k0_m_k1_,
+                    arg.a_grid_desc_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
...
@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
+                printf("DeviceOp err: AccDataType");
                 return false;
             }
         }
         else
         {
+            printf("DeviceOp err: Arch");
             return false;
         }

-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_);
...
@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock << ", "
+            << KPerBlock << ", "
             << K1 << ", "
-            << MPerWMMA << ", "
-            << NPerWMMA << ", "
+            << MPerWmma << ", "
+            << NPerWmma << ", "
             << MRepeat << ", "
             << NRepeat << ">"
...
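One consequence of the new A descriptor is visible in the GetK helper above: with LDS enabled the descriptor is (K0, M, K1), while the bypass path carries a six-dimensional per-thread layout, so the total K must be reassembled from different lengths. A small standalone sketch of that index arithmetic (the lengths below are made up; only the formula mirrors the diff):

#include <array>
#include <cstdio>

// Sketch of the GetK logic: index 0/2 are K0/K1 in the 3-D (LDS) layout,
// index 0/3/5 are KWmma/KRow/K1 in the 6-D (VGPR bypass) layout.
template <bool AEnableLds>
int GetK(const std::array<int, 6>& len)
{
    if constexpr(AEnableLds)
        return len[0] * len[2];           // K0 * K1
    else
        return len[0] * len[3] * len[5];  // KWmma * KRow * K1
}

int main()
{
    // Hypothetical descriptor lengths, both describing K = 64.
    std::printf("LDS path K    = %d\n", GetK<true>({8, 128, 8, 0, 0, 0}));  // 8 * 8
    std::printf("bypass path K = %d\n", GetK<false>({4, 1, 4, 2, 16, 8}));  // 4 * 2 * 8
    return 0;
}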
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp

@@ -15,6 +15,8 @@ enum struct PipelineVersion
 };

 template <PipelineVersion PipelineVer,
+          bool AEnableLds         = true,
+          bool BEnableLds         = true,
           index_t NumPrefetch     = 1,
           LoopScheduler LoopSched = LoopScheduler::Default>
 constexpr auto GridwiseGemmPipeline_Selector()
...
@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+        return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

@@ -8,12 +8,12 @@
 namespace ck {

-template <index_t NumPrefetch>
+template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<1>
+struct GridwiseGemmPipeline_v1<1, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
 // 2-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipeline_v1<2, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2>
     }
 };

+template <>
+struct GridwiseGemmPipeline_v1<1, false, true>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+#if 0
+        constexpr auto a_block_origin_idx = generate_sequence_v2(
+            []() constexpr { return Number<0>{}; },
+            Number<a_block_desc.GetLengths().GetSize()>{});
+#endif
+        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        auto a_block_buf_switch           = a_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+            do
+            {
+                a_blockwise_copy.Run(
+                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
+                block_sync_lds();
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                block_sync_lds();
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+                a_block_buf = a_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
+// placeholder
+template <>
+struct GridwiseGemmPipeline_v1<1, true, false>
+{
+};
+
+template <>
+struct GridwiseGemmPipeline_v1<1, false, false>
+{
+};
+
 template <index_t NumPrefetch>
 struct GridwiseGemmPipelineInterwave_v1;
...
@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
 // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
 template <>
-struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
 {
 };
...
@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+        return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
     {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
d4adc71a
...
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AGridDesc
_K0_M_K1
,
typename
AGridDesc
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
...
...
@@ -33,30 +33,26 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_wmma
(
const
FloatA
*
__restrict__
p_a_grid
,
kernel_gemm_wmma
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc
_K0_M_K1
a_grid_desc
_k0_m_k1
,
const
AGridDesc
a_grid_desc
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
// const
// CGridDescriptor_MBlockxRepeat_MWave_MSubGroup_MAccVgprs_NBlockxRepeat_NWave_NThreadPerSubGroup
// c_grid_desc_mblockxrepeat_mwave_msubgroup_maccvgprs_nblockxrepeat_nwave_nthreadpersubgroup,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
__shared__
char
p_shared
[
GridwiseGemm
::
Get
SharedMem
oryNumberOfByte
()
];
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMem
Trait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
a_grid_desc
_k0_m_k1
,
a_grid_desc
,
b_grid_desc_k0_n_k1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
...
...
@@ -67,7 +63,7 @@ __global__ void
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc
_k0_m_k1
;
ignore
=
a_grid_desc
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
...
...
@@ -84,7 +80,7 @@ template <index_t BlockSize,
typename
FloatCShuffle
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc
_K0_M_K1
,
typename
AGridDesc
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
AElementwiseOperation
,
...
...
@@ -92,7 +88,7 @@ template <index_t BlockSize,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K
0
PerBlock
,
index_t
KPerBlock
,
index_t
MPerWmma
,
index_t
NPerWmma
,
index_t
K1Value
,
...
...
@@ -105,6 +101,7 @@ template <index_t BlockSize,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
AEnableLds
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
...
...
@@ -113,6 +110,7 @@ template <index_t BlockSize,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BEnableLds
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
...
...
@@ -132,39 +130,149 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
B_K0
=
BGridDesc_K0_N_K1
{}.
GetLength
(
I0
);
static
constexpr
auto
B_K1
=
BGridDesc_K0_N_K1
{}.
GetLength
(
I2
);
// FIX ME: To be deprecated
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
16
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
AEnableLds
,
BEnableLds
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K0_M0_M1_M2_K1
(
const
ABlockDesc_AK0_M_AK1
&
)
// Describe how data store to (LDS/VGPR) buffer from Global memory
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor
()
{
constexpr
auto
a_block_desc
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
// K0->M->K1 Per Block
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
max_lds_align
=
K1
;
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}
else
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
K1
),
make_tuple
(
Number
<
MRepeat
>
{}
*
K1
,
K1
,
K1
,
K1
,
K1
,
I1
));
}
}();
return
a_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
return
make_multi_index
(
K0PerBlock
,
0
,
0
);
}
else
{
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
);
}
}();
return
a_block_copy_step
;
}
__host__
__device__
static
constexpr
auto
MakeBBlockSliceCopyStep
()
{
constexpr
auto
b_block_copy_step
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
return
make_multi_index
(
K0PerBlock
,
0
,
0
);
}
else
{
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
);
}
}();
return
b_block_copy_step
;
}
// Describe how data read from (LDS/VGPR) buffer
template
<
typename
ABlockDesc_
>
__host__
__device__
static
constexpr
auto
MakeAWaveDescriptor
(
const
ABlockDesc_
&
)
{
constexpr
index_t
A_K0
=
ABlockDesc_AK0_M_AK1
{}.
GetLength
(
I0
);
constexpr
index_t
A_K1
=
ABlockDesc_AK0_M_AK1
{}.
GetLength
(
I2
);
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
constexpr
auto
a_wave_desc
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
ABlockDesc_
AK0_M_AK1
{},
ABlockDesc_
{},
make_tuple
(
make_pass_through_transform
(
Number
<
A_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
else
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr
auto
KWmma
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I5
);
return
transform_tensor_descriptor
(
ABlockDesc_
{},
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
KWmma
>
{},
I1
)),
make_pass_through_transform
(
Number
<
MRepeat
>
{}),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
}
}();
return
a_wave_desc
;
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K0_N0_N1_N2_K1
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
B_K0
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I0
);
constexpr
index_t
B_K1
=
BBlockDesc_BK0_N_BK1
{}.
GetLength
(
I2
);
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
return
transform_tensor_descriptor
(
BBlockDesc_BK0_N_BK1
{},
make_tuple
(
make_pass_through_transform
(
Number
<
B_K0
>
{}),
...
...
@@ -175,32 +283,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
a_block_desc_k0perblock_mperblock_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
...
...
@@ -223,44 +309,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerWmma
>
{},
Number
<
CShuffleMRepeatPerShuffle
*
MWave
s
*
MPerWmma
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerWmma
>
{}));
Number
<
CShuffleNRepeatPerShuffle
*
NWave
s
*
NPerWmma
>
{}));
return
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_k0perblock_mperblock_k1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_k0perblock_nperblock_k1
.
GetElementSpaceSize
(),
max_lds_align
);
return
(
a_block_space_size_aligned
*
sizeof
(
FloatA
)
+
b_block_space_size_aligned
*
sizeof
(
FloatB
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc
_K0_M_K1
&
a_grid_desc
_k0_m_k1
,
CheckValidity
(
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
...
...
@@ -272,23 +334,68 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
(
NPerBlock
%
(
NRepeat
*
NPerWmma
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
const
auto
GetAProblemsizeMK
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
{
return
make_tuple
(
a_grid_desc
.
GetLength
(
I1
),
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I2
));
}
else
{
return
make_tuple
(
a_grid_desc
.
GetLength
(
I1
)
*
a_grid_desc
.
GetLength
(
I2
)
*
a_grid_desc
.
GetLength
(
I4
),
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I5
));
}
};
const
auto
GetBProblemsizeNK
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
return
make_tuple
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
),
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
*
b_grid_desc_k0_n_k1
.
GetLength
(
I2
));
}
else
{
return
make_tuple
(
b_grid_desc_k0_n_k1
.
GetLength
(
I1
)
*
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)
*
b_grid_desc_k0_n_k1
.
GetLength
(
I4
),
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
*
b_grid_desc_k0_n_k1
.
GetLength
(
I3
)
*
b_grid_desc_k0_n_k1
.
GetLength
(
I5
));
}
};
const
auto
M
=
GetAProblemsizeMK
()[
I0
];
const
auto
N
=
GetBProblemsizeNK
()[
I0
];
const
auto
K
=
GetAProblemsizeMK
()[
I1
];
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)
&&
K0
==
b_grid_desc_k0_n_k1
.
GetLength
(
I0
)
&&
K1
==
a_grid_desc_k0_m_k1
.
GetLength
(
I2
)
&&
K1
==
b_grid_desc_k0_n_k1
.
GetLength
(
I2
)))
K
==
GetBProblemsizeNK
()[
I1
]))
{
printf
(
"A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d
\n
"
,
GetAProblemsizeMK
()[
I0
],
GetAProblemsizeMK
()[
I1
],
GetBProblemsizeNK
()[
I0
],
GetBProblemsizeNK
()[
I1
],
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
printf
(
"GridwiseOp err: ProblemSize check"
);
return
false
;
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
{
printf
(
"GridwiseOp err: ProblemSize division"
);
return
false
;
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
0
/
K
0
PerBlock
;
const
auto
num_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
printf
(
"GridwiseOp err: Pipeline not support this k_loop"
);
return
false
;
}
...
...
@@ -303,7 +410,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0
PerBlock
*
K1
)
;
const
index_t
num_loop
=
K
/
K
PerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
...
...
@@ -340,12 +447,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
struct
SharedMemTrait
{
// LDS allocation for A and B: be careful of alignment
static
constexpr
auto
max_lds_align
=
K1
;
static
constexpr
auto
a_block_space_size_aligned
=
AEnableLds
?
math
::
integer_least_multiple
(
MakeABlockDescriptor
().
GetElementSpaceSize
(),
max_lds_align
)
*
sizeof
(
FloatA
)
:
0
;
static
constexpr
auto
b_block_space_size_aligned
=
BEnableLds
?
math
::
integer_least_multiple
(
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
().
GetElementSpaceSize
(),
max_lds_align
)
*
sizeof
(
FloatB
)
:
0
;
static
constexpr
auto
a_block_space_offset
=
0
;
static
constexpr
auto
b_block_space_offset
=
a_block_space_size_aligned
;
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_space_size
=
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
.
GetElementSpaceSize
()
*
sizeof
(
FloatCShuffle
);
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
lds_size
=
math
::
max
(
c_shuffle_block_space_size
,
(
a_block_space_size_aligned
+
b_block_space_size_aligned
));
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
const
AGridDesc
_K0_M_K1
&
a_grid_desc
_k0_m_k1
,
const
AGridDesc
&
a_grid_desc
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -358,7 +498,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// Memory buffer zone.
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc
_k0_m_k1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
@@ -378,14 +518,32 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
/*******************************************************************************/
// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy
const
auto
K
=
[
&
](){
if
constexpr
(
AEnableLds
){
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I2
);
}
else
{
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I5
);
}
}();
// printf("---------------K = %d\n", K);
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
b_block_desc
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
auto
a_block_trait
=
[
&
](){
// A matrix blockwise copy
if
constexpr
(
AEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatA
*>
(
p_shared
),
a_block_desc
.
GetElementSpaceSize
());
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
/* typename SrcElementwiseOperation, */
AElementwiseOperation
,
/* typename DstElementwiseOperation, */
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
/* InMemoryDataOperationEnum DstInMemOp, */
InMemoryDataOperationEnum
::
Set
,
...
...
@@ -394,8 +552,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* typename ThreadClusterArrangeOrder, */
ABlockTransferThreadClusterArrangeOrder
,
/* typename SrcData, */
FloatA
,
/* typename DstData, */
FloatA
,
/* typename SrcDesc, */
decltype
(
a_grid_desc
_k0_m_k1
),
/* typename DstDesc, */
decltype
(
a_block_desc
_k0perblock_mperblock_k1
),
/* typename SrcDesc, */
decltype
(
a_grid_desc
),
/* typename DstDesc, */
decltype
(
a_block_desc
),
/* typename SrcDimAccessOrder, */
ABlockTransferSrcAccessOrder
,
/* typename DstDimAccessOrder, */
Sequence
<
0
,
1
,
2
>
,
/* index_t SrcVectorDim, */
ABlockTransferSrcVectorDim
,
...
...
@@ -406,14 +564,60 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
AThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
>
(
a_grid_desc
_k0_m_k1
,
a_grid_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc
_k0perblock_mperblock_k1
,
a_block_desc
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
return
make_tuple
(
a_block_buf
,
a_blockwise_copy
);
}
else
{
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto
a_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
FloatA
,
FloatA
,
decltype
(
a_grid_desc
),
decltype
(
a_block_desc
),
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
ABlockTransferSrcScalarPerVector
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_grid_desc
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
/
(
MWaves
*
MPerWmma
),
get_thread_local_1d_id
()
/
32
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
a_block_buf
,
a_blockwise_copy
);
}
};
auto
b_block_trait
=
[
&
](){
if
constexpr
(
BEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatB
*>
(
p_shared
)
+
SharedMemTrait
::
a_block_space_size_aligned
,
b_block_desc
.
GetElementSpaceSize
());
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
...
...
@@ -425,7 +629,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
FloatB
,
FloatB
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc
_k0perblock_nperblock_k1
),
decltype
(
b_block_desc
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
...
...
@@ -439,13 +643,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
b_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc
_k0perblock_nperblock_k1
,
b_block_desc
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
return
make_tuple
(
b_block_buf
,
b_blockwise_copy
);
}
else
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_block_desc
.
GetElementSpaceSize
());
auto
b_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc
),
Sequence
<
Number
<
K0PerBlock
>
{},
Number
<
NRepeat
>
{},
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BBlockTransferSrcScalarPerVector
,
1
>
(
make_multi_index
(
0
,
get_thread_local_1d_id
()
/
32
*
16
+
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
b_block_buf
,
b_blockwise_copy
);
}
};
auto
a_block_buf
=
a_block_trait
()[
I0
];
auto
a_blockwise_copy
=
a_block_trait
()[
I1
];
auto
b_block_buf
=
b_block_trait
()[
I0
];
auto
b_blockwise_copy
=
b_block_trait
()[
I1
];
/*******************************************************************************/
// GEMM
constexpr
auto
WmmaK
=
16
;
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
auto
blockwise_gemm
=
...
...
@@ -453,11 +686,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
FloatA
,
FloatB
,
FloatAcc
,
decltype
(
MakeA
Block
Descriptor
_K0_M0_M1_M2_K1
(
a_block_desc_k0perblock_mperblock_k1
)),
decltype
(
MakeBBlockDescriptor_K0_N0_N1_N2_K1
(
b_block_desc
_k0perblock_nperblock_k1
)),
decltype
(
MakeA
Wave
Descriptor
(
a_block_desc
)),
decltype
(
MakeBBlockDescriptor_K0_N0_N1_N2_K1
(
b_block_desc
)),
MPerBlock
,
NPerBlock
,
K
0
PerBlock
*
K1
,
KPerBlock
,
MPerWmma
,
NPerWmma
,
MRepeat
,
...
...
@@ -468,25 +701,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
/*******************************************************************************/
-   constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-       a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align);
-
-   // LDS allocation for A and B: be careful of alignment
-   auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-       static_cast<FloatA*>(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize());
-
-   auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-       static_cast<FloatB*>(p_shared) + a_block_space_size_aligned,
-       b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize());

    // Shift Per SUB_K
-   constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
-   constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+   constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep();
+   // printf("a_block_slice_copy_step FirstKdim = %d\n", a_block_slice_copy_step[I0]);
+   constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep();

    // gridwise GEMM pipeline
-   const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-   GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
-                                                     a_block_desc_k0perblock_mperblock_k1,
+   const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock);
+   GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
+                                                     a_block_desc,
                                                      a_blockwise_copy,
                                                      a_grid_buf,
                                                      a_block_buf,
                                                      a_block_slice_copy_step,
                                                      b_grid_desc_k0_n_k1,
-                                                     b_block_desc_k0perblock_nperblock_k1,
+                                                     b_block_desc,
                                                      b_blockwise_copy,
                                                      b_grid_buf,
                                                      b_block_buf,
...
...
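The pipeline call above also changes the main-loop trip count from `K0 / K0PerBlock` to `K / KPerBlock`. With the usual K0/K1 split (`K = K0 * K1`, and `K0PerBlock = KPerBlock / K1` as defined in the block trait earlier in this hunk), the two quotients are equal, so only the bookkeeping moves from the K0 view to the flat-K view. A compile-time check with made-up example sizes, as an illustration only:

// Illustration only: with K = K0 * K1 and KPerBlock = K0PerBlock * K1,
// K0 / K0PerBlock == K / KPerBlock, so the loop count is unchanged.
constexpr int K1_example         = 8;
constexpr int K0_example         = 32;
constexpr int K_example          = K0_example * K1_example;         // 256
constexpr int K0PerBlock_example = 4;
constexpr int KPerBlock_example  = K0PerBlock_example * K1_example; // 32
static_assert(K0_example / K0PerBlock_example == K_example / KPerBlock_example,
              "same number of main-loop iterations");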
@@ -497,6 +726,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
/*******************************************************************************/
// write out to C, implement shuffle
{
#if 0
static_for<0, c_thread_buf.Size(), 1>{}([&](auto i) {
printf("tid: %03d, c_thread_buf[%02d] val: %08x\n", get_thread_local_1d_id(), i.value,
*(reinterpret_cast<const uint32_t*>(&(c_thread_buf[i]))));
// c_thread_buf(i) = 32;
});
#endif
        constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
            blockwise_gemm
                .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
...
...
@@ -515,8 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatCShuffle*>(p_shared),
-           c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
+           static_cast<FloatCShuffle*>(p_shared), SharedMemTrait::c_shuffle_block_space_size);

        constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
            transform_tensor_descriptor(
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
...
...
@@ -666,6 +901,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp View file @ d4adc71a
...
...
@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
    using Index = MultiIndex<nDim>;

-   __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
-       const ElementwiseOperation& element_op)
-       : element_op_{element_op}
+   __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");

+       ignore = src_idx;
    }

    template <typename SrcSliceOriginIdx,
...
...
@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                               const SrcBuffer& src_buf,
                               const DstDesc&,
                               const DstSliceOriginIdx&,
-                              DstBuffer& dst_buf)
+                              DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");
...
...
@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
        // copy data from src_buf into dst_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-           // idx_md err. as dst access 2 strided elements while src visit 1 per loop
+           // src_desc error, non constexpr?
            constexpr index_t src_offset = src_desc.CalculateOffset(
                src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
...
...
@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
            // apply element-wise operation
            element_op_(v_this_row, src_buf[Number<src_offset>{}]);

-           // apply intra-row swizzle permute
+           // if (get_thread_local_1d_id() < 16)
+           // printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
+           // *(reinterpret_cast<uint16_t*>(&v_this_row)) ); apply intra-row swizzle permute
            if constexpr(IntraRowSwizzlePerm)
            {
-               // origin: 0xfedcba98, 0x76543210
-               temp = __builtin_amdgcn_permlane16(
-                   temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
-               v_this_row = type_convert<float>(temp);
+               temp = __builtin_amdgcn_permlane16(
+                   // 0x76543210, 0xfedcba98
+                   temp, type_convert<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
+               v_this_row = type_convert<SrcData>(temp);
+               // if (get_thread_local_1d_id() < 16)
+               // printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
+               // *(reinterpret_cast<uint16_t*>(&v_this_row)) );
            }
// apply inter-row permute.
...
...
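The intra-row swizzle above is built on `v_permlane16_b32`: the two selector words supply eight 4-bit source-lane indices each, covering lanes 0-7 and 8-15 of a 16-lane row, and every lane takes the value held by the lane its nibble selects. The reference model below is a host-side sketch of that selection only (an assumption based on the public ISA description, not code from this commit; the `fi`/`bound_ctrl` arguments and the `old` operand of the real builtin are ignored), applied to the new selector pair `0xb3a29180` / `0xf7e6d5c4`:

#include <array>
#include <cstdint>
#include <cstdio>

// Model one 16-lane row: lane i reads the nibble-selected source lane of the same row.
std::array<int, 16>
permlane16_row(const std::array<int, 16>& src, std::uint32_t sel_lo, std::uint32_t sel_hi)
{
    std::array<int, 16> dst{};
    for(int lane = 0; lane < 16; ++lane)
    {
        const std::uint32_t sel    = lane < 8 ? sel_lo : sel_hi;
        const int           nibble = lane < 8 ? lane : lane - 8;
        const int           from   = (sel >> (4 * nibble)) & 0xF; // source lane inside the row
        dst[lane]                  = src[from];
    }
    return dst;
}

int main()
{
    std::array<int, 16> v{};
    for(int i = 0; i < 16; ++i)
        v[i] = i;

    const auto out = permlane16_row(v, 0xb3a29180u, 0xf7e6d5c4u);
    for(int i = 0; i < 16; ++i)
        std::printf("lane %2d <- lane %2d\n", i, out[i]);
    // With this selector pair lanes 0..15 read from lanes 0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15,
    // i.e. the low and high halves of the row are interleaved element by element.
    return 0;
}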
@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                                               HighEightRowLaneIdx, 1, 0);
-           v_theother_row = type_convert<float>(temp);
+           v_theother_row = type_convert<SrcData>(temp);
+           // printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
+           // *(reinterpret_cast<uint16_t*>(&v_theother_row)) );

            if(get_thread_local_1d_id() % 32 < 16)
            {
                // apply type convert
...
...
@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
            });
        });
    }

-   ElementwiseOperation element_op_;
+   ElementwiseOperation element_op_{};
};

} // namespace ck
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp View file @ d4adc71a
...
...
@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              bool AssemblyBackend,
              class FloatA,
              class FloatB,
              class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        if constexpr(wave_size == 32)
...
...
@@ -492,11 +497,13 @@ struct WmmaGemm
"(int8, int32) or (int4, int32)!"
);
if
constexpr
(
!
TransposeC
)
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
else
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
,
AssemblyBackend
>(
p_b_wave
,
p_a_wave
,
p_c_thread
);
}
}
...
...
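In the `TransposeC` branch above the operands are handed to the instruction in swapped order. The motivation is the identity (A*B)^T = B^T * A^T: producing the transposed output layout only requires swapping which fragment feeds which operand slot, while the exact fragment-to-lane layouts are handled by the surrounding descriptors. A tiny host-side check of the identity itself, as an illustration only (not commit code):

#include <array>
#include <cassert>

using Mat2 = std::array<std::array<int, 2>, 2>;

// Plain 2x2 integer matrix multiply.
Mat2 mul(const Mat2& x, const Mat2& y)
{
    Mat2 r{};
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
            for(int k = 0; k < 2; ++k)
                r[i][j] += x[i][k] * y[k][j];
    return r;
}

Mat2 transpose(const Mat2& x)
{
    Mat2 r{};
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
            r[i][j] = x[j][i];
    return r;
}

int main()
{
    const Mat2 a{{{1, 2}, {3, 4}}};
    const Mat2 b{{{5, 6}, {7, 8}}};
    assert(transpose(mul(a, b)) == mul(transpose(b), transpose(a))); // (A*B)^T == B^T * A^T
    return 0;
}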
include/ck/utility/amd_wmma.hpp View file @ d4adc71a
...
...
@@ -21,12 +21,15 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        if constexpr(AssemblyBackend)
        {
            amd_assembly_wmma_f32_16x16x16_f16_w32(
                reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        }
        else
        {
            reg_c.template AsType<float8_t>()(Number<0>{}) =
                __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
                    reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
        }
    }
...
...
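For reference, the compiler builtin used in the fallback path above can also be exercised from a bare HIP kernel. The sketch below is an assumption, not part of this commit: it only shows the register shapes (16 fp16 inputs and 8 fp32 accumulators per lane of a wave32) and the call, fills the fragments with a trivial pattern, and ignores the canonical operand-to-lane layout. It requires a gfx11 target built in wave32 mode, and the local vector typedefs are stand-ins for CK's `half16_t` / `float8_t`.

#include <hip/hip_runtime.h>

// Stand-ins for ck::half16_t / ck::float8_t (assumed typedefs, not taken from the library).
typedef _Float16 half16_demo __attribute__((ext_vector_type(16)));
typedef float float8_demo __attribute__((ext_vector_type(8)));

__global__ void wmma_f32_16x16x16_f16_demo(float* out)
{
    half16_demo a;
    half16_demo b;
    float8_demo c = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};

    // Each lane of the wave32 holds 16 fp16 elements of A and 16 of B.
    for(int k = 0; k < 16; ++k)
    {
        a[k] = static_cast<_Float16>(1.0f);
        b[k] = static_cast<_Float16>(1.0f);
    }

    // One 16x16x16 wave32 WMMA; every lane receives 8 fp32 accumulator values back.
    c = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, c);

    const int lane = threadIdx.x % 32;
    for(int i = 0; i < 8; ++i)
        out[lane * 8 + i] = c[i];
}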
include/ck/utility/data_type.hpp View file @ d4adc71a
...
...
@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
    return u.fp32;
}

template <>
inline __host__ __device__ constexpr int type_convert<int, half_t>(half_t x)
{
    union
    {
        half_t fp16;
        int int32;
    } u = {x};

    return u.int32;
}

template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, int>(int x)
{
    union
    {
        int int32;
        half_t fp16;
    } u = {x};

    return u.fp16;
}

// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
...
...
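The two new `type_convert` specializations are what the inter-row transfer above uses to move fp16 values through the 32-bit integer lanes that `__builtin_amdgcn_permlane16` shuffles: they are pure bit reinterpretation through a union, and on a little-endian target only the low 16 bits of the `int` carry the half-precision pattern. A standalone equivalent of the same idea, shown as an illustration only (with `uint16_t` standing in for the `half_t` bit pattern, not code from the commit):

#include <cstdint>
#include <cstring>

// Widen a 16-bit payload into a 32-bit integer and narrow it back, mirroring the
// int <-> half_t converters above; the upper 16 bits are simply left at zero.
int bits16_to_int(std::uint16_t h)
{
    int i = 0;
    std::memcpy(&i, &h, sizeof(h));
    return i;
}

std::uint16_t int_to_bits16(int i)
{
    std::uint16_t h = 0;
    std::memcpy(&h, &i, sizeof(h));
    return h;
}

int main()
{
    const std::uint16_t pattern = 0x3C00; // fp16 bit pattern for 1.0
    return int_to_bits16(bits16_to_int(pattern)) == pattern ? 0 : 1; // round-trips exactly
}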