gaoqiong / composable_kernel · Commits · 047cee2b

Commit 047cee2b, authored Jul 20, 2022 by Anthony Chang

    compiles

Parent: 68b71534

Showing 10 changed files with 1035 additions and 214 deletions
example/01_gemm/gemm_gemm_xdl_fp16.cpp                                              +89  -24
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                     +145 -62
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp   +15  -1
include/ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp            +134 -69
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_cshuffle_v1.hpp         +450 -55
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp         +132 -0
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                +61  -2
include/ck/utility/static_buffer.hpp                                                +2   -0
include/ck/utility/tuple_helper.hpp                                                 +6   -0
library/include/ck/library/utility/check_err.hpp                                    +1   -1
example/01_gemm/gemm_gemm_xdl_fp16.cpp

@@ -54,25 +54,75 @@ using CElementOp = PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 
-// clang-format off
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
-//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
-//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
-//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-    <ALayout, B0Layout, CLayout, ADataType, B0DataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
-// clang-format on
+using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle
+    <ALayout,
+     B0Layout,
+     B1Layout,
+     CLayout,
+     ADataType,
+     B0DataType,
+     CDataType,
+     AccDataType,
+     CShuffleDataType,
+     AElementOp,
+     BElementOp,
+     CElementOp,
+     GemmDefault,
+     1,
+     256,
+     128,            // MPerBlock
+     128,            // NPerBlock
+     32,             // KPerBlock
+     128,            // Gemm1NPerBlock
+     32,             // Gemm1KPerBlock
+     8,              // AK1
+     8,              // BK1
+     2,              // B1K1
+     32,             // MPerXDL
+     32,             // NPerXDL
+     1,              // MXdlPerWave
+     4,              // NXdlPerWave
+     4,              // Gemm1NXdlPerWave
+     S<4, 64, 1>,    // ABlockTransfer
+     S<1, 0, 2>,
+     S<1, 0, 2>,
+     2,
+     8,
+     8,
+     true,
+     S<4, 64, 1>,    // BBlockTransfer
+     S<1, 0, 2>,
+     S<1, 0, 2>,
+     2,
+     8,
+     8,
+     true,
+     S<8, 32, 1>,    // B1BlockTransfer
+     S<0, 2, 1>,
+     S<0, 2, 1>,
+     1,
+     4,
+     2,
+     false,
+     1,              // CShuffleMXdlPerWavePerShuffle
+     2,              // CShuffleNXdlPerWavePerShuffle
+     S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+     8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock
 
 using ReferenceGemm0Instance = ck::tensor_operation::host::
     ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
 
 using ReferenceGemm1Instance = ck::tensor_operation::host::
     ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
 
 int main(int argc, char* argv[])
 {
     bool do_verification = true;
-    int init_method      = 1;
+    // int init_method = 1;
+    int init_method      = 3;
     bool time_kernel     = false;
 
     // GEMM shape
@@ -87,13 +137,13 @@ int main(int argc, char* argv[])
     // ck::index_t StrideC = 1024;
     ck::index_t M = 256;
-    ck::index_t N = 256;
+    ck::index_t N = 128;
     ck::index_t K = 32;
-    ck::index_t O = 256;
+    ck::index_t O = 128;
-    ck::index_t StrideA  = 256;
-    ck::index_t StrideB0 = 256;
-    ck::index_t StrideB1 = 256;
-    ck::index_t StrideC  = 256;
+    ck::index_t StrideA  = 32;
+    ck::index_t StrideB0 = 32;
+    ck::index_t StrideB1 = 128;
+    ck::index_t StrideC  = 128;
 
     if(argc == 1)
     {
@@ -165,14 +215,16 @@ int main(int argc, char* argv[])
         b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
         break;
     case 2:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
+        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
+        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+        // b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
         b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+        // b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
     }
 
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
@@ -182,6 +234,7 @@ int main(int argc, char* argv[])
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
+    b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
 
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
@@ -192,12 +245,15 @@ int main(int argc, char* argv[])
     auto invoker  = gemm.MakeInvoker();
     auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                       static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
                                       M,
                                       N,
                                       K,
+                                      O,
                                       StrideA,
                                       StrideB0,
+                                      StrideB1,
                                       StrideC,
                                       a_element_op,
                                       b_element_op,
@@ -244,6 +300,15 @@ int main(int argc, char* argv[])
         ref_gemm1_invoker.Run(ref_gemm1_argument);
 
+        // LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
+        // LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;
+        std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0)
+                  << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
+                  << ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1)
+                  << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1) << std::endl;
+
         return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
     }
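The example above instantiates the fused kernel for C[M, O] = A[M, K] * B0[K, N] * B1[N, O] and validates it against ReferenceGemm0 and ReferenceGemm1 run back to back. Below is a minimal host-side sketch of what that two-stage reference computes — plain row-major loop nests with illustrative names, not the CK API:

#include <cstddef>
#include <cstdio>
#include <vector>

// acc[M,N] = A[M,K] * B0[K,N], then C[M,O] = acc[M,N] * B1[N,O]; all row-major.
void reference_gemm_gemm(const std::vector<float>& a, const std::vector<float>& b0,
                         const std::vector<float>& b1, std::vector<float>& c,
                         std::size_t M, std::size_t N, std::size_t K, std::size_t O)
{
    std::vector<float> acc(M * N, 0.f); // Acc0 = A * B0
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t k = 0; k < K; ++k)
                acc[m * N + n] += a[m * K + k] * b0[k * N + n];

    for(std::size_t m = 0; m < M; ++m) // C = Acc0 * B1
        for(std::size_t o = 0; o < O; ++o)
        {
            float sum = 0.f;
            for(std::size_t n = 0; n < N; ++n)
                sum += acc[m * N + n] * b1[n * O + o];
            c[m * O + o] = sum;
        }
}

int main()
{
    const std::size_t M = 2, N = 2, K = 2, O = 2;
    // A and B1 are identity, so C must equal B0.
    std::vector<float> a{1, 0, 0, 1}, b0{1, 2, 3, 4}, b1{1, 0, 0, 1}, c(M * O);
    reference_gemm_gemm(a, b0, b1, c, M, N, K, O);
    std::printf("c = %.0f %.0f %.0f %.0f\n", c[0], c[1], c[2], c[3]); // 1 2 3 4
}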
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

@@ -25,16 +25,27 @@ constexpr LoopScheduler make_default_loop_scheduler()
 #endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
 }
 
+// Blockwise gemm supporting both regular XDL output M2_M3_M4_M2 and transposed XDL output
+// M2_N2_N3_N4. The latter is similar to "SourceSwap" seen in Tensile
+// TODO ANT: rename class to reflect the above fact
 template <index_t BlockSize,
           typename FloatAB,
           typename FloatAcc,
-          typename AK0MK1BlockDesc,
+          typename AK0MK1BlockDesc, // could be thread desc
           typename BK0NK1BlockDesc,
+          typename AMmaTileDesc,
+          typename BMmaTileDesc,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t MRepeat,
           index_t NRepeat,
-          index_t KPack>
+          index_t KPack,
+          bool TransposeC = false,
+          index_t AMmaKStride =
+              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
+          index_t BMmaKStride =
+              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
 {
     static constexpr auto I0 = Number<0>{};

@@ -46,23 +57,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     static constexpr index_t WaveSize = get_warp_size();
 
-    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
-    static constexpr index_t KPerBlock =
-        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
+    // static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
+    // static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
+    // static constexpr index_t KPerBlock =
+    //     BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
 
     static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
     static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
 
-    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
+    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
 
     static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
 
     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
 
+    // StaticBuffer<AddressSpaceEnum::Vgpr,
+    //              FloatAcc,
+    //              MRepeat * NRepeat * xdlops_gemm.GetRegSizePerXdlops(),
+    //              true>
+    //     c_thread_buf_;
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               FloatAcc,
                               MRepeat * NRepeat,

@@ -92,7 +108,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
 
-        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
+        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
     }
 
     __device__ static auto CalculateBThreadOriginDataIndex()

@@ -103,7 +119,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
 
-        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
+        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
     }
 
     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>

@@ -135,10 +151,30 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         return make_tuple(c_thread_m, c_thread_n);
     }
 
-    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
+    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
+
+    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(
+        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
+        Tuple4 b_origin = CalculateBThreadOriginDataIndex())
+        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
-        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
-                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
+#if 0
+        if(!TransposeC && hipThreadIdx_x % 32 < 8)
+        {
+            printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
+                   hipBlockIdx_x,
+                   hipThreadIdx_x,
+                   a_origin[Number<0>{}],
+                   a_origin[Number<1>{}],
+                   a_origin[Number<2>{}],
+                   a_origin[Number<3>{}],
+                   b_origin[Number<0>{}],
+                   b_origin[Number<1>{}],
+                   b_origin[Number<2>{}],
+                   b_origin[Number<3>{}]);
+        }
+#endif
+        static_assert(AMmaTileDesc::IsKnownAtCompileTime() &&
+                          BMmaTileDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
 
         static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,

@@ -148,6 +184,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                       "wrong!");
     }
 
+    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(
+        const BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1& other)
+        : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
+    {
+    }
+
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
+
+        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
+        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
+        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
+        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+
+        return make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

@@ -174,6 +231,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
     }
 
+    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
+    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
+    {
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
+                                                           Number<NRepeat>{},
+                                                           Number<MWaves>{},
+                                                           Number<NWaves>{},
+                                                           Number<MPerXDL>{},
+                                                           Number<NPerXDL>{}));
+
+        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
+    }
+
+    // XDL output supporting C_xdl = A_xdl * B_xdl
     __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
     {
         constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =

@@ -239,33 +311,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
             c_grid_desc_g_m0_n0_m1_n1_m2_n2);
     }
 
-    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
-    {
-        return transform_tensor_descriptor(
-            AK0MK1BlockDesc{},
-            make_tuple(make_merge_transform_v3_division_mod(
-                           make_tuple(Number<A_K0>{}, Number<A_K1>{})),
-                       make_unmerge_transform(
-                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
-    }
-
-    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
-    {
-        return transform_tensor_descriptor(
-            BK0NK1BlockDesc{},
-            make_tuple(make_merge_transform_v3_division_mod(
-                           make_tuple(Number<B_K0>{}, Number<B_K1>{})),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
-            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
-    }
-
-    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
-    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
+    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
+    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
 
+    // NOTE ANT: a_block_buf for the 2nd gemm is vgpr buffer
     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,

@@ -276,33 +325,65 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
             b_thread_desc_.GetElementSpaceSize());
 
+        // static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
+        static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=kpack*[0, 1, 2]
         static_for<0, MRepeat, 1>{}([&](auto m0) {
-            // read A
+            // read A1 without stride
             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
-                               make_tuple(m0, I0, I0, I0),
+                               make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                a_block_buf,
                                a_thread_desc_,
                                make_tuple(I0, I0, I0, I0),
                                a_thread_buf);
 
             static_for<0, NRepeat, 1>{}([&](auto n0) {
-                // read B
+                // read B with stride
                 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-                                   make_tuple(n0, I0, I0, I0),
+                                   make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                    b_block_buf,
                                    b_thread_desc_,
                                    make_tuple(I0, I0, I0, I0),
                                    b_thread_buf);
 
-                static_for<0, KPerThread, KPack>{}([&](auto k) {
+#if 0
+                if (!TransposeC && hipThreadIdx_x % 32 < 8) {
+                    printf("bid %zd tid %zd, mma tile %d %d %d, a[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, b[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f\n",
+                        hipBlockIdx_x, hipThreadIdx_x, m0.value, n0.value, k.value,
+                        // (float)a_thread_buf[Number<0>{}],
+                        // (float)a_thread_buf[Number<1>{}],
+                        // (float)a_thread_buf[Number<2>{}],
+                        // (float)a_thread_buf[Number<3>{}],
+                        // (float)b_thread_buf[Number<0>{}],
+                        // (float)b_thread_buf[Number<1>{}],
+                        // (float)b_thread_buf[Number<2>{}],
+                        // (float)b_thread_buf[Number<3>{}]
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
+                        (float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
+                        (float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}]
+                    );
+                }
+#endif
                 vector_type<FloatAB, KPack> a_thread_vec;
                 vector_type<FloatAB, KPack> b_thread_vec;
 
+                // xdlops_gemm.K0PerXdlops
+                // TODO ANT: add appropriate iteration delta
                 static_for<0, KPack, 1>{}([&](auto i) {
                     a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
-                        [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                        [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                     b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
-                        [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
+                        [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                 });
 
                 using mfma_input_type =

@@ -337,7 +418,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         FloatAB,
         decltype(a_block_desc_m0_m1_m2_k),
         decltype(a_thread_desc_),
-        Sequence<1, 1, 1, KPerThread>,
+        Sequence<1, 1, 1, KPack>,
         Sequence<0, 1, 2, 3>,
         3,
         A_K1,

@@ -347,16 +428,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
         FloatAB,
         decltype(b_block_desc_n0_n1_n2_k),
         decltype(b_thread_desc_),
-        Sequence<1, 1, 1, KPerThread>,
+        Sequence<1, 1, 1, KPack>,
         Sequence<0, 1, 2, 3>,
         3,
         B_K1,
         B_K1>;
 
-    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
-    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
+    AThreadCopy a_thread_copy_; // {CalculateAThreadOriginDataIndex()};
+    BThreadCopy b_thread_copy_; // {CalculateBThreadOriginDataIndex()};
 };
 
+#if 0
 // Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
 // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the

@@ -584,5 +666,6 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
                                                            KPack>{};
 }
 };
+#endif
 
 } // namespace ck
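The TransposeC path added above leans on the identity (A·B)ᵀ = Bᵀ·Aᵀ: feeding the MFMA with swapped operands (the "SourceSwap" trick the class comment mentions) leaves the accumulators holding the transposed C tile. A tiny standalone check of that identity, with illustrative 2x2 matrices:

#include <array>
#include <cassert>

int main()
{
    // 2x2 row-major matrices
    std::array<float, 4> A{1, 2, 3, 4}, B{5, 6, 7, 8};
    std::array<float, 4> AB{}, BtAt{};

    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
            for(int k = 0; k < 2; ++k)
            {
                AB[i * 2 + j] += A[i * 2 + k] * B[k * 2 + j];
                // B^T(i,k) = B(k,i), A^T(k,j) = A(j,k)
                BtAt[i * 2 + j] += B[k * 2 + i] * A[j * 2 + k];
            }

    // (A*B)^T == B^T * A^T, element by element
    for(int i = 0; i < 2; ++i)
        for(int j = 0; j < 2; ++j)
            assert(AB[i * 2 + j] == BtAt[j * 2 + i]);
}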
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

@@ -81,7 +81,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
                 make_multi_index(ThreadGroup::GetThreadId()));
 
             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
 
+#if 0
+            if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
+            {
+                auto s = src_block_slice_origin + thread_data_idx_begin;
+                auto d = dst_block_slice_origin + thread_data_idx_begin;
+                printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
+                    hipBlockIdx_x, hipThreadIdx_x,
+                    s[Number<0>{}],
+                    s[Number<1>{}],
+                    s[Number<2>{}],
+                    d[Number<0>{}],
+                    d[Number<1>{}],
+                    d[Number<2>{}]);
+            }
+#endif
             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                    src_block_slice_origin + thread_data_idx_begin);
             threadwise_transfer_.SetDstSliceOrigin(dst_desc,
include/ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp

@@ -24,29 +24,36 @@ namespace device {
 // version currently has compiler issues with register spill which further causes validation
 // failures.
 //
 // Computes C = A * B0 * B1
+//              ^^^^^^ (Acc0)
+//              ^^^^^^^^^^^ (Acc1)
 template <typename ALayout,
-          typename BLayout,
+          typename BLayout, // B0Layout
+          typename B1Layout,
           typename CLayout,
           typename ADataType,
-          typename BDataType,
+          typename BDataType, // NOTE: don't distinguish B0/B1 type just yet
           typename CDataType,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename AElementwiseOperation,
-          typename BElementwiseOperation,
+          typename BElementwiseOperation, // NOTE: don't distinguish B0/B1 type just yet
           typename CElementwiseOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
-          index_t NPerBlock,
-          index_t KPerBlock,
+          index_t NPerBlock, // Gemm0NPerBlock
+          index_t KPerBlock, // Gemm0KPerBlock
+          index_t Gemm1NPerBlock,
+          index_t Gemm1KPerBlock,
           index_t AK1,
           index_t BK1,
+          index_t B1K1,
           index_t MPerXDL,
           index_t NPerXDL,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
+          index_t Gemm1NXdlPerWave,
           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,

@@ -61,20 +68,19 @@ template <typename ALayout,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_BK1,
           bool BBlockLdsExtraN,
+          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+          typename B1BlockTransferThreadClusterArrangeOrder,
+          typename B1BlockTransferSrcAccessOrder,
+          index_t B1BlockTransferSrcVectorDim,
+          index_t B1BlockTransferSrcScalarPerVector,
+          index_t B1BlockTransferDstScalarPerVector_BK1,
+          bool B1BlockLdsExtraN,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
-struct DeviceGemmGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
-                                                       BLayout,
-                                                       CLayout,
-                                                       ADataType,
-                                                       BDataType,
-                                                       CDataType,
-                                                       AElementwiseOperation,
-                                                       BElementwiseOperation,
-                                                       CElementwiseOperation>
+struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
 {
     using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;

@@ -288,6 +294,44 @@ struct DeviceGemmGemm_Xdl_CShuffle
         }
     }
 
+    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
+    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    {
+        const auto b1_grid_desc_nraw_kraw = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
+                                                    make_tuple(StrideB, I1));
+            }
+        }();
+
+        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
+        const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;
+
+        const auto NPad = N - NRaw;
+        const auto KPad = K - KRaw;
+
+        // TODO ANT: implement padding
+        // not pad N or K
+        assert(KRaw % B1K1 == 0);
+
+        const auto B1K0 = KRaw / B1K1;
+
+        const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+            b1_grid_desc_nraw_kraw,
+            make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
+                       make_pass_through_transform(NRaw)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+        return b1_grid_desc_bk0_n_bk1;
+    }
+
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
     {
         const auto c_grid_desc_mraw_nraw = [&]() {

@@ -304,7 +348,7 @@ struct DeviceGemmGemm_Xdl_CShuffle
         }();
 
         const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
-        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
 
         const auto MPad = M - MRaw;
         const auto NPad = N - NRaw;

@@ -348,6 +392,7 @@ struct DeviceGemmGemm_Xdl_CShuffle
     using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
     using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
+    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
 
     // GridwiseGemm

@@ -362,18 +407,23 @@ struct DeviceGemmGemm_Xdl_CShuffle
         InMemoryDataOperationEnum::Set,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
+        B1GridDesc_BK0_N_BK1,
         CGridDesc_M_N,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
         NPerBlock,
         KPerBlock,
+        Gemm1NPerBlock,
+        Gemm1KPerBlock,
         AK1,
        BK1,
+        B1K1,
         MPerXDL,
         NPerXDL,
         MXdlPerWave,
         NXdlPerWave,
+        Gemm1NXdlPerWave,
         ABlockTransferThreadClusterLengths_AK0_M_AK1,
         ABlockTransferThreadClusterArrangeOrder,
         ABlockTransferSrcAccessOrder,

@@ -390,6 +440,14 @@ struct DeviceGemmGemm_Xdl_CShuffle
         BBlockTransferDstScalarPerVector_BK1,
         false,
         BBlockLdsExtraN,
+        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
+        B1BlockTransferThreadClusterArrangeOrder,
+        B1BlockTransferSrcAccessOrder,
+        B1BlockTransferSrcVectorDim,
+        B1BlockTransferSrcScalarPerVector,
+        B1BlockTransferDstScalarPerVector_BK1,
+        false,
+        B1BlockLdsExtraN,
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -401,22 +459,27 @@ struct DeviceGemmGemm_Xdl_CShuffle
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
+                 const BDataType* p_b1_grid,
                  CDataType* p_c_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
+                 index_t Gemm1NRaw, // = ORaw
                  index_t StrideA,
                  index_t StrideB,
+                 index_t StrideB1,
                  index_t StrideC,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
+              p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
               a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
               b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
-              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
+              b1_grid_desc_bk0_n_bk1_{
+                  DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
+              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
               block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
               a_element_op_{a_element_op},

@@ -425,6 +488,7 @@ struct DeviceGemmGemm_Xdl_CShuffle
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
+                                           b1_grid_desc_bk0_n_bk1_,
                                            c_grid_desc_m_n_,
                                            block_2_ctile_map_))
             {

@@ -437,9 +501,11 @@ struct DeviceGemmGemm_Xdl_CShuffle
         // private:
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
+        const BDataType* p_b1_grid_;
         CDataType* p_c_grid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;

@@ -473,8 +539,10 @@ struct DeviceGemmGemm_Xdl_CShuffle
             }
 #endif
+            // TODO ANT: block id to ctilemap should infer acc0tile map
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                             arg.b_grid_desc_bk0_n_bk1_,
+                                            arg.b1_grid_desc_bk0_n_bk1_,
                                             arg.c_grid_desc_m_n_,
                                             arg.block_2_ctile_map_))
             {

@@ -484,13 +552,13 @@ struct DeviceGemmGemm_Xdl_CShuffle
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
 
+            // TODO ANT: K for gemm1
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
 
             float ave_time = 0;
 
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            auto launch_kernel = [&](auto has_main_k_block_loop_)
             {
                 const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype

@@ -500,57 +568,38 @@ struct DeviceGemmGemm_Xdl_CShuffle
                     CElementwiseOperation,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
+                    DeviceOp::B1GridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
-                    true>;
+                    has_main_k_block_loop_>;
 
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.a_grid_desc_ak0_m_ak1_,
-                                                  arg.b_grid_desc_bk0_n_bk1_,
-                                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                                  arg.block_2_ctile_map_);
+                return launch_and_time_kernel(stream_config,
+                                              kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              arg.p_a_grid_,
+                                              arg.p_b_grid_,
+                                              arg.p_b1_grid_,
+                                              arg.p_c_grid_,
+                                              arg.a_element_op_,
+                                              arg.b_element_op_,
+                                              arg.c_element_op_,
+                                              arg.a_grid_desc_ak0_m_ak1_,
+                                              arg.b_grid_desc_bk0_n_bk1_,
+                                              arg.b1_grid_desc_bk0_n_bk1_,
+                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                              arg.block_2_ctile_map_);
+            };
+
+            // TODO ANT: handle tail loops for gemm0 & gemm1
+            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            {
+                ave_time = launch_kernel(integral_constant<bool, true>{});
             }
             else
             {
-                const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    CDataType,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CElementwiseOperation,
-                    DeviceOp::AGridDesc_AK0_M_AK1,
-                    DeviceOp::BGridDesc_BK0_N_BK1,
-                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DefaultBlock2CTileMap,
-                    false>;
-
-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.a_grid_desc_ak0_m_ak1_,
-                                                  arg.b_grid_desc_bk0_n_bk1_,
-                                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = launch_kernel(integral_constant<bool, false>{});
             }
 
             return ave_time;

@@ -579,6 +628,7 @@ struct DeviceGemmGemm_Xdl_CShuffle
             return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                                arg.b_grid_desc_bk0_n_bk1_,
+                                               arg.b1_grid_desc_bk0_n_bk1_,
                                                arg.c_grid_desc_m_n_,
                                                arg.block_2_ctile_map_);
         }

@@ -591,12 +641,15 @@ struct DeviceGemmGemm_Xdl_CShuffle
     static auto MakeArgument(const ADataType* p_a,
                              const BDataType* p_b,
+                             const BDataType* p_b1,
                              CDataType* p_c,
                              index_t MRaw,
                              index_t NRaw,
                              index_t KRaw,
+                             index_t Gemm1NRaw,
                              index_t StrideA,
                              index_t StrideB,
+                             index_t StrideB1,
                              index_t StrideC,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,

@@ -604,12 +657,15 @@ struct DeviceGemmGemm_Xdl_CShuffle
     {
         return Argument{p_a,
                         p_b,
+                        p_b1,
                         p_c,
                         MRaw,
                         NRaw,
                         KRaw,
+                        Gemm1NRaw,
                         StrideA,
                         StrideB,
+                        StrideB1,
                         StrideC,
                         a_element_op,
                         b_element_op,

@@ -621,25 +677,31 @@ struct DeviceGemmGemm_Xdl_CShuffle
     // polymorphic
     std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                       const void* p_b,
+                                                      const void* p_b1,
                                                       void* p_c,
                                                       index_t MRaw,
                                                       index_t NRaw,
                                                       index_t KRaw,
+                                                      index_t Gemm1NRaw,
                                                       index_t StrideA,
                                                       index_t StrideB,
+                                                      index_t StrideB1,
                                                       index_t StrideC,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
-                                                      CElementwiseOperation c_element_op) override
+                                                      CElementwiseOperation c_element_op) /* override */
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
+                                          static_cast<const BDataType*>(p_b1),
                                           static_cast<CDataType*>(p_c),
                                           MRaw,
                                           NRaw,
                                           KRaw,
+                                          Gemm1NRaw,
                                           StrideA,
                                           StrideB,
+                                          StrideB1,
                                           StrideC,
                                           a_element_op,
                                           b_element_op,

@@ -647,7 +709,7 @@ struct DeviceGemmGemm_Xdl_CShuffle
     }
 
     // polymorphic
-    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() /* override */
     {
         return std::make_unique<Invoker>(Invoker{});
     }

@@ -658,15 +720,18 @@ struct DeviceGemmGemm_Xdl_CShuffle
         auto str = std::stringstream();
 
         // clang-format off
-        str << "DeviceGemm_Xdl_CShuffle"
+        str << "DeviceGemmGemm_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
-            << ">";
+            << BK1 << ", "
+            << NPerBlock << ", "
+            << Gemm1NPerBlock << ", "
+            << Gemm1KPerBlock << ", "
+            << B1K1
+            << ">";
         // clang-format on
 
         return str.str();
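MakeB1GridDescriptor_BK0_N_BK1 and MakeCGridDescriptor_M_N round problem lengths up to multiples of the block tile so the grid evenly covers the problem (actual padding for B1 is still a TODO in this commit). A small standalone sketch of that round-up arithmetic, with integer_divide_ceil re-created locally and illustrative sizes:

#include <cstdio>

int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int NRaw = 130, Gemm1NPerBlock = 128;
    // Padded length is the smallest tile multiple >= NRaw.
    const int N    = integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
    const int NPad = N - NRaw;
    std::printf("N = %d, NPad = %d\n", N, NPad); // N = 256, NPad = 126
}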
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_cshuffle_v1.hpp

(diff collapsed in the original view; not shown)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1145,9 +1145,52 @@ struct ThreadwiseTensorSliceTransfer_v4
                 src_desc, src_data_coord);
 
             // copy data from src_buf into src_tmp_vector
-            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+#if 0
+            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+#else
+            if constexpr(SrcBuffer::IsDynamicBuffer())
+            {
+                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                    src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+            }
+            else if constexpr(SrcBuffer::IsStaticBuffer())
+            {
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
+                        i * src_scalar_step_in_vector);
+
+                    // apply type convert
+                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
+                });
+                // if constexpr(StaticBufferTupleOfVector)
+                // {
+                //     // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
+                //     // // offset_nd.foo();
+                //     // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
+                //     // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                //     //     src_buf.template GetAsType<src_vector_t>(Number<offset>{});
+                //     static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                //         // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
+                //         //     data_to_origin_disp_idx + i * src_scalar_step_in_vector;
+                //         // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
+                //         constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
+                //         // SrcData s = src_buf[Number<src_offset>{}];
+                //         SrcData s = src_buf[Number<0>{}];
+                //         // apply type convert
+                //         src_tmp_vector.template AsType<SrcData>()(i) = s;
+                //     });
+                // }
+                // else
+                // {
+                //     src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                //         src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
+                //         is_src_valid);
+                // }
+            }
+#endif
 
             // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
             // DstData)
             vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

@@ -1184,4 +1227,93 @@ struct ThreadwiseTensorSliceTransfer_v4
     SrcCoord src_ref_coord_;
 };
 
+// Do NOT involve any tensor coordinates with StaticBuffer
+template <typename SrcData,
+          typename DstData,
+          typename SrcDesc,
+          typename DstDesc,
+          typename SliceLengths,
+          typename DimAccessOrder,
+          index_t DstVectorDim,
+          index_t DstScalarPerVector,
+          // InMemoryDataOperationEnum DstInMemOp,
+          // index_t DstScalarStrideInVector,
+          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                             bool>::type = false>
+struct ThreadwiseTensorSliceTransfer_v1r3_Static
+{
+    static constexpr index_t nDim = SliceLengths::Size();
+
+    using Index = MultiIndex<nDim>;
+
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
+                      "wrong! Not divisible");
+    }
+
+    template <typename SrcSliceOriginIdx,
+              typename DstSliceOriginIdx,
+              typename SrcBuffer,
+              typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc&,
+                        const DstSliceOriginIdx&,
+                        DstBuffer& dst_buf)
+    {
+        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc need to known at compile-time");
+
+        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
+                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
+                      "wrong! SliceOrigin need to known at compile-time");
+
+        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
+                      "wrong! Buffer need to be StaticBuffer");
+
+        // SrcDesc and src_slice_origin_idx are known at compile-time
+        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
+        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
+        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
+        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
+
+        // scalar per access on each dim
+        constexpr auto dst_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto dst_scalar_step_in_vector =
+            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
+
+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
+                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
+
+        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
+        using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
+
+        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_access, 1>{}([&](auto idx_1d) {
+            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+            // copy data from src_buf into dst_vector
+            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                constexpr index_t src_offset = src_desc.CalculateOffset(
+                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                dst_buf(Number<dst_offset>{}) = src_buf[Number<src_offset>{}];
+            });
+        });
+    }
+};
+
 } // namespace ck
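The new ThreadwiseTensorSliceTransfer_v1r3_Static requires descriptors, slice origins, and buffers to all be static, so every offset folds to a compile-time constant and the copy degenerates to straight-line register moves. A simplified standalone sketch of that idea, with a constexpr row-major "descriptor" standing in for CK's tensor descriptors (illustrative names only):

#include <array>
#include <cstddef>

template <int Rows, int Cols>
struct StaticDesc
{
    // Offset is a pure function of compile-time shape, so it folds to a constant.
    static constexpr int CalculateOffset(int r, int c) { return r * Cols + c; }
};

template <typename SrcDesc, typename DstDesc, int Rows, int Cols, typename T,
          std::size_t NS, std::size_t ND>
constexpr void static_copy(const std::array<T, NS>& src, std::array<T, ND>& dst)
{
    for(int r = 0; r < Rows; ++r) // fully unrollable: all offsets are constants
        for(int c = 0; c < Cols; ++c)
            dst[DstDesc::CalculateOffset(r, c)] = src[SrcDesc::CalculateOffset(r, c)];
}

int main()
{
    std::array<float, 8> src{0, 1, 2, 3, 4, 5, 6, 7};
    std::array<float, 8> dst{};
    static_copy<StaticDesc<2, 4>, StaticDesc<2, 4>, 2, 4>(src, dst);
    return dst[5] == 5.f ? 0 : 1;
}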
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -30,6 +30,17 @@ enum struct MfmaInstr
     mfma_f64_16x16x4f64
 };
 
+// template <typename T, bool TransposeC>
+// struct mfma_base_type
+// {
+//     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
+//     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
+//     {
+//         if constexpr (!TransposeC) T::run(a, b, reg_c);
+//         else T::run(b, a, reg_c);
+//     }
+// };
+
 template <MfmaInstr instr>
 struct mfma_type;

@@ -579,7 +590,11 @@ struct MfmaSelector
     static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
 };
 
-template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
+template <typename base_type,
+          index_t MPerXdlops,
+          index_t NPerXdlops,
+          index_t KPack,
+          bool TransposeC = false>
 struct XdlopsGemm
 {
     static constexpr auto I0 = Number<0>{};

@@ -612,6 +627,8 @@ struct XdlopsGemm
         static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
     }
 
+    // XDL output supporting C = A * B
+    // M2_N2 -> M2_M3_M4_N2
     template <typename CDesc_M0_N0_M1_N1_M2_N2>
     __host__ __device__ static constexpr auto
     MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)

@@ -645,6 +662,41 @@ struct XdlopsGemm
                        Sequence<7>{}));
     }
 
+    // transposed XDL output supporting C' = B' * A'
+    // M2_N2 -> M2_N2_N3_N4
+    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    __host__ __device__ static constexpr auto
+    MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
+    {
+        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
+        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
+        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
+        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+
+        return transform_tensor_descriptor(
+            c_desc_m0_n0_m1_n1_m2_n2,
+            make_tuple(make_pass_through_transform(M0),
+                       make_pass_through_transform(N0),
+                       make_pass_through_transform(M1),
+                       make_pass_through_transform(N1),
+                       make_pass_through_transform(mfma_instr.num_threads_per_blk),
+                       make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
+                                                         mfma_instr.num_input_blks,
+                                                         mfma_instr.group_size))),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5, 6, 7>{}));
+    }
+
     template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
     __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
         const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)

@@ -698,7 +750,14 @@ struct XdlopsGemm
                       "base base_type must be double, float, half, bfloat16, and int8_t!");
 
         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
-            mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
+            if constexpr(!TransposeC)
+            {
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
+            }
+            else
+            {
+                mfma_instr.template run<MPerXdlops, NPerXdlops>(p_b_wave[k], p_a_wave[k], p_c_thread);
+            }
        });
    }
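Run() consumes a KPack-wide per-thread slice in k_per_blk chunks, issuing one MFMA per chunk — hence the static_assert that KPack divides evenly by k_per_blk. A scalar stand-in for that sub-iteration, with illustrative sizes (the real code dispatches hardware intrinsics):

#include <array>
#include <cstdio>

constexpr int KPack     = 8;
constexpr int k_per_blk = 4; // per-instruction K depth (hypothetical value)

int main()
{
    std::array<float, KPack> a{}, b{};
    for(int i = 0; i < KPack; ++i) { a[i] = 1.f; b[i] = 2.f; }

    float c = 0.f;
    static_assert(KPack % k_per_blk == 0, "KPack cannot be divided by k_per_blk");
    for(int k = 0; k < KPack / k_per_blk; ++k) // one "instruction" per chunk
        for(int i = 0; i < k_per_blk; ++i)     // the intrinsic's internal K reduction
            c += a[k * k_per_blk + i] * b[k * k_per_blk + i];

    std::printf("c = %.1f\n", c); // 16.0
}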
include/ck/utility/static_buffer.hpp

@@ -61,6 +61,7 @@ struct StaticBufferTupleOfVector
     static constexpr auto s_per_v   = Number<ScalarPerVector>{};
     static constexpr auto num_of_v_ = Number<NumOfVector>{};
+    static constexpr auto s_per_buf = s_per_v * num_of_v_;
 
     __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}

@@ -70,6 +71,7 @@ struct StaticBufferTupleOfVector
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
 
+    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
+
     // Get S
     // i is offset of S
     template <index_t I>
include/ck/utility/tuple_helper.hpp

@@ -78,4 +78,10 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
 }
 
+template <index_t... Is>
+__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
+{
+    return Tuple<Number<Is>...>(Number<Is>{}...);
+}
+
 } // namespace ck
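A hypothetical usage sketch of the new to_tuple helper: it lifts a compile-time Sequence<Is...> into a tuple of Number<Is> values so the indices can flow through tuple-based utilities. Minimal stand-ins for ck::Number, ck::Sequence, and ck::Tuple keep the sketch self-contained:

#include <tuple>

// Stand-ins for ck::Number / ck::Sequence so the sketch compiles on its own.
template <int I>
struct Number { static constexpr int value = I; };
template <int... Is>
struct Sequence {};

template <int... Is>
constexpr std::tuple<Number<Is>...> to_tuple(Sequence<Is...>)
{
    return std::tuple<Number<Is>...>(Number<Is>{}...);
}

int main()
{
    auto t = to_tuple(Sequence<0, 1, 2>{});
    // Each element carries its index in its type.
    static_assert(std::tuple_element_t<2, decltype(t)>::value == 2, "");
    return 0;
}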
library/include/ck/library/utility/check_err.hpp

@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 5)
+            if(err_count < 128)
             {
                 std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
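The change above raises the cap on printed mismatches from 5 to 128 while still tracking the running maximum error. A condensed standalone sketch of that reporting pattern (not the actual check_err signature):

#include <cmath>
#include <cstdio>
#include <vector>

bool check_err_sketch(const std::vector<float>& out, const std::vector<float>& ref,
                      float atol = 1e-3f, int max_reports = 128)
{
    float max_err = 0.f;
    int err_count = 0;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol)
        {
            max_err = err > max_err ? err : max_err; // keep worst error seen
            if(err_count++ < max_reports)            // but only print the first few
                std::printf("out[%zu] != ref[%zu]: %f != %f\n", i, i, out[i], ref[i]);
        }
    }
    return err_count == 0;
}

int main()
{
    std::vector<float> out{1.f, 2.f, 3.5f}, ref{1.f, 2.f, 3.f};
    return check_err_sketch(out, ref) ? 0 : 1; // reports the one mismatch
}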