gaoqiong / composable_kernel

Commit 047cee2b, authored Jul 20, 2022 by Anthony Chang

compiles

parent 68b71534

Showing 10 changed files with 1035 additions and 214 deletions (+1035, -214)
example/01_gemm/gemm_gemm_xdl_fp16.cpp  (+89 -24)
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  (+145 -62)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp  (+15 -1)
include/ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp  (+134 -69)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_cshuffle_v1.hpp  (+450 -55)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (+132 -0)
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  (+61 -2)
include/ck/utility/static_buffer.hpp  (+2 -0)
include/ck/utility/tuple_helper.hpp  (+6 -0)
library/include/ck/library/utility/check_err.hpp  (+1 -1)
example/01_gemm/gemm_gemm_xdl_fp16.cpp
@@ -54,25 +54,75 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    <ALayout, B0Layout, CLayout, ADataType, B0DataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

using ReferenceGemm0Instance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
    ALayout,
    B0Layout,
    B1Layout,
    CLayout,
    ADataType,
    B0DataType,
    CDataType,
    AccDataType,
    CShuffleDataType,
    AElementOp,
    BElementOp,
    CElementOp,
    GemmDefault,
    1,
    256,
    128,             // MPerBlock
    128,             // NPerBlock
    32,              // KPerBlock
    128,             // Gemm1NPerBlock
    32,              // Gemm1KPerBlock
    8,               // AK1
    8,               // BK1
    2,               // B1K1
    32,              // MPerXDL
    32,              // NPerXDL
    1,               // MXdlPerWave
    4,               // NXdlPerWave
    4,               // Gemm1NXdlPerWave
    S<4, 64, 1>,     // ABlockTransfer
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    8,
    8,
    true,
    S<4, 64, 1>,     // BBlockTransfer
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    8,
    8,
    true,
    S<8, 32, 1>,     // B1BlockTransfer
    S<0, 2, 1>,
    S<0, 2, 1>,
    1,
    4,
    2,
    false,
    1,               // CShuffleMXdlPerWavePerShuffle
    2,               // CShuffleNXdlPerWavePerShuffle
    S<1, 32, 1, 8>,  // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
    8>;              // CShuffleBlockTransferScalarPerVector_NPerBlock

using ReferenceGemm0Instance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

using ReferenceGemm1Instance = ck::tensor_operation::host::
    ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
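For orientation: the two reference instances above verify the chained GEMM C[M, O] = (A[M, K] * B0[K, N]) * B1[N, O], with the first product held in AccDataType. Below is a minimal host-side sketch of that check, assuming the example's usual tensor names (a_m_k, b0_k_n, b1_n_o, c_m_o_host_result) and that AccDataType is float; it is illustrative only, not the library's reference implementation.

```cpp
// Naive chained GEMM: acc(m, n) = sum_k A(m, k) * B0(k, n), then C(m, o) = sum_n acc(m, n) * B1(n, o).
for(int m = 0; m < M; ++m)
{
    for(int o = 0; o < O; ++o)
    {
        float c = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += static_cast<float>(a_m_k(m, k)) * static_cast<float>(b0_k_n(k, n));
            c += acc * static_cast<float>(b1_n_o(n, o));
        }
        c_m_o_host_result(m, o) = static_cast<CDataType>(c);
    }
}
```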
int main(int argc, char* argv[])
{
    bool do_verification = true;
    // int init_method = 1;
    int init_method = 3;
    int init_method = 1;
    bool time_kernel = false;

    // GEMM shape
@@ -87,13 +137,13 @@ int main(int argc, char* argv[])
// ck::index_t StrideC = 1024;
    ck::index_t M = 256;
    ck::index_t N = 256;
    ck::index_t N = 128;
    ck::index_t K = 32;
    ck::index_t O = 256;
    ck::index_t StrideA  = 256;
    ck::index_t StrideB0 = 256;
    ck::index_t StrideB1 = 256;
    ck::index_t StrideC  = 256;
    ck::index_t O = 128;
    ck::index_t StrideA  = 32;
    ck::index_t StrideB0 = 32;
    ck::index_t StrideB1 = 128;
    ck::index_t StrideC  = 128;
    if(argc == 1)
    {
@@ -165,14 +215,16 @@ int main(int argc, char* argv[])
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
        b0_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        // b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
        b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        // b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
    }

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
@@ -182,6 +234,7 @@ int main(int argc, char* argv[])
    a_m_k_device_buf.ToDevice(a_m_k.mData.data());
    b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
    b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
@@ -192,12 +245,15 @@ int main(int argc, char* argv[])
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                                      static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
                                      static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()),
                                      static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
                                      M,
                                      N,
                                      K,
                                      O,
                                      StrideA,
                                      StrideB0,
                                      StrideB1,
                                      StrideC,
                                      a_element_op,
                                      b_element_op,
@@ -244,6 +300,15 @@ int main(int argc, char* argv[])
    ref_gemm1_invoker.Run(ref_gemm1_argument);
// LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;
    std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0)
              << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
              << ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1)
              << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1) << std::endl;

    return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
@@ -25,16 +25,27 @@ constexpr LoopScheduler make_default_loop_scheduler()
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
}
// Blockwise gemm supporting both regular XDL output M2_M3_M4_M2 and transposed XDL output
// M2_N2_N3_N4. The latter is similar to "SourceSwap" seen in Tensile
// TODO ANT: rename class to reflect the above fact
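// Why the operand swap gives the transposed tile: C^T = (A * B)^T = B^T * A^T, so issuing the MFMA
// with swapped operands (B first, then A) accumulates the transposed C directly; see the TransposeC
// branch added to XdlopsGemm::Run in xdlops_gemm.hpp further down in this commit.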
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename AK0MK1BlockDesc, // could be thread desc
          typename BK0NK1BlockDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
          index_t KPack,
          bool TransposeC = false,
          index_t AMmaKStride =
              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
          index_t BMmaKStride =
              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
    static constexpr auto I0 = Number<0>{};
@@ -46,23 +57,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
    // static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    // static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    // static constexpr index_t KPerBlock =
    //     BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    // StaticBuffer<AddressSpaceEnum::Vgpr,
    //              FloatAcc,
    //              MRepeat * NRepeat * xdlops_gemm.GetRegSizePerXdlops(),
    //              true>
    //     c_thread_buf_;
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
@@ -92,7 +108,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
@@ -103,7 +119,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -135,10 +151,30 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        return make_tuple(c_thread_m, c_thread_n);
    }

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()

    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(
        Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
        Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && BK0NK1BlockDesc::IsKnownAtCompileTime(),
#if 0
if(!TransposeC && hipThreadIdx_x % 32 < 8)
{
printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
a_origin[Number<0>{}],
a_origin[Number<1>{}],
a_origin[Number<2>{}],
a_origin[Number<3>{}],
b_origin[Number<0>{}],
b_origin[Number<1>{}],
b_origin[Number<2>{}],
b_origin[Number<3>{}]);
}
#endif
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
@@ -148,6 +184,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!"
);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
(
const
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
&
other
)
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
{
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
    }
// XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -174,6 +231,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }
// XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
@@ -239,33 +311,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
// NOTE ANT: a_block_buf for the 2nd gemm is vgpr buffer
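    // In the chained GEMM the first GEMM's accumulator stays in VGPRs and becomes the A operand of
    // the second GEMM, so a_block_buf may be a static (register) buffer rather than LDS; this appears
    // to be what the IsStaticBuffer() read path added to ThreadwiseTensorSliceTransfer_v4 in this
    // commit supports.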
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
@@ -276,33 +325,65 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
// static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
        static_for<0, KPerThread / KPack, 1>{}([&](auto k) {
// k=0,1,2 instead of k=kpack*[0, 1, 2]
            static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A1 without stride
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);
                                   a_thread_buf);
                static_for<0, KPerThread, KPack>{}([&](auto k) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B with stride
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf);
#if 0
if (!TransposeC && hipThreadIdx_x % 32 < 8) {
printf("bid %zd tid %zd, mma tile %d %d %d, a[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, b[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f\n",
hipBlockIdx_x, hipThreadIdx_x, m0.value, n0.value, k.value,
// (float)a_thread_buf[Number<0>{}],
// (float)a_thread_buf[Number<1>{}],
// (float)a_thread_buf[Number<2>{}],
// (float)a_thread_buf[Number<3>{}],
// (float)b_thread_buf[Number<0>{}],
// (float)b_thread_buf[Number<1>{}],
// (float)b_thread_buf[Number<2>{}],
// (float)b_thread_buf[Number<3>{}]
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}]
);
}
#endif
                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;
// xdlops_gemm.K0PerXdlops
// TODO ANT: add appropriate iteration delta
                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                    });
                    using mfma_input_type =
@@ -337,7 +418,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                FloatAB,
                                                decltype(a_block_desc_m0_m1_m2_k),
                                                decltype(a_thread_desc_),
                                                Sequence<1, 1, 1, KPerThread>,
                                                Sequence<1, 1, 1, KPack>,
                                                Sequence<0, 1, 2, 3>,
                                                3,
                                                A_K1,
@@ -347,16 +428,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                                FloatAB,
                                                decltype(b_block_desc_n0_n1_n2_k),
                                                decltype(b_thread_desc_),
                                                Sequence<1, 1, 1, KPerThread>,
                                                Sequence<1, 1, 1, KPack>,
                                                Sequence<0, 1, 2, 3>,
                                                3,
                                                B_K1,
                                                B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
    AThreadCopy a_thread_copy_; // {CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_; // {CalculateBThreadOriginDataIndex()};
};
#if 0
// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
@@ -584,5 +666,6 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
KPack>{};
}
};
#endif
} // namespace ck
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
@@ -81,7 +81,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
            make_multi_index(ThreadGroup::GetThreadId()));

        const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
#if 0
if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
{
auto s = src_block_slice_origin + thread_data_idx_begin;
auto d = dst_block_slice_origin + thread_data_idx_begin;
printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
hipBlockIdx_x, hipThreadIdx_x,
s[Number<0>{}],
s[Number<1>{}],
s[Number<2>{}],
d[Number<0>{}],
d[Number<1>{}],
d[Number<2>{}]);
}
#endif
        threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);

        threadwise_transfer_.SetDstSliceOrigin(dst_desc,
include/ck/tensor_operation/gpu/device/device_gemm_gemm_xdl_cshuffle.hpp
This diff is collapsed.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_cshuffle_v1.hpp
This diff is collapsed.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1145,9 +1145,52 @@ struct ThreadwiseTensorSliceTransfer_v4
            src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
#if 0
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
#else
        if constexpr(SrcBuffer::IsDynamicBuffer())
        {
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
        }
        else if constexpr(SrcBuffer::IsStaticBuffer())
        {
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                    i * src_scalar_step_in_vector);
// apply type convert
                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
            });
// if constexpr(StaticBufferTupleOfVector)
// {
// // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
// // // offset_nd.foo();
// // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
// // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// // src_buf.template GetAsType<src_vector_t>(Number<offset>{});
// static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
// // data_to_origin_disp_idx + i * src_scalar_step_in_vector;
// // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
// constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
// // SrcData s = src_buf[Number<src_offset>{}];
// SrcData s = src_buf[Number<0>{}];
// // apply type convert
// src_tmp_vector.template AsType<SrcData>()(i) = s;
// });
// }
// else
// {
// src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
// is_src_valid);
// }
}
#endif
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
        vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
@@ -1184,4 +1227,93 @@ struct ThreadwiseTensorSliceTransfer_v4
    SrcCoord src_ref_coord_;
};
// Do NOT involve any tensor coordinates with StaticBuffer
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          // InMemoryDataOperationEnum DstInMemOp,
          // index_t DstScalarStrideInVector,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_Static
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;
    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
    }
    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! SliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
                      "wrong! Buffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
// scalar per access on each dim
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve =
            SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(dst_scalar_per_access)>>;

        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
        using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                dst_buf(Number<dst_offset>{}) = src_buf[Number<src_offset>{}];
            });
        });
    }
};
} // namespace ck
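A rough usage sketch of the ThreadwiseTensorSliceTransfer_v1r3_Static helper added above (all descriptors, buffers and parameter values here are hypothetical; the real call sites are in the gridwise kernel whose diff is collapsed earlier on this page): it copies a fully compile-time slice between two static (register) buffers, vectorizing along DstVectorDim, without constructing any tensor coordinates.

```cpp
// Hypothetical example (not from this commit): move a 1x1x1x8 half-precision slice between two
// static thread buffers, writing 8 scalars per vectorized access on the last dimension.
constexpr auto thread_desc = make_naive_tensor_descriptor_packed(
    make_tuple(Number<1>{}, Number<1>{}, Number<1>{}, Number<8>{}));

using Copy = ThreadwiseTensorSliceTransfer_v1r3_Static<half_t,                // SrcData
                                                       half_t,                // DstData
                                                       decltype(thread_desc), // SrcDesc
                                                       decltype(thread_desc), // DstDesc
                                                       Sequence<1, 1, 1, 8>,  // SliceLengths
                                                       Sequence<0, 1, 2, 3>,  // DimAccessOrder
                                                       3,                     // DstVectorDim
                                                       8>;                    // DstScalarPerVector

// src_thread_buf / dst_thread_buf are assumed to be static buffers (e.g. made with
// make_static_buffer), so SrcBuffer::IsStaticBuffer() and DstBuffer::IsStaticBuffer() hold.
Copy{}.Run(thread_desc,
           make_tuple(I0, I0, I0, I0), // compile-time src slice origin
           src_thread_buf,
           thread_desc,
           make_tuple(I0, I0, I0, I0), // compile-time dst slice origin
           dst_thread_buf);
```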
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -30,6 +30,17 @@ enum struct MfmaInstr
    mfma_f64_16x16x4f64
};
// template <typename T, bool TransposeC>
// struct mfma_base_type
// {
// template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
// __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
// {
// if constexpr (!TransposeC) T::run(a, b, reg_c);
// else T::run(b, a, reg_c);
// }
// };
template <MfmaInstr instr>
struct mfma_type;
@@ -579,7 +590,11 @@ struct MfmaSelector
    static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
};

template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
template <typename base_type,
          index_t MPerXdlops,
          index_t NPerXdlops,
          index_t KPack,
          bool TransposeC = false>
struct XdlopsGemm
{
    static constexpr auto I0 = Number<0>{};
@@ -612,6 +627,8 @@ struct XdlopsGemm
        static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
    }
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
@@ -645,6 +662,41 @@ struct XdlopsGemm
                       Sequence<7>{}));
    }
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
    template <typename CDesc_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(
        const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
    {
        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);

        return transform_tensor_descriptor(
            c_desc_m0_n0_m1_n1_m2_n2,
            make_tuple(make_pass_through_transform(M0),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(M1),
                       make_pass_through_transform(N1),
                       make_pass_through_transform(mfma_instr.num_threads_per_blk),
                       make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
                                                         mfma_instr.num_input_blks,
                                                         mfma_instr.group_size))),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5, 6, 7>{}));
    }
    template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
    __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
        const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
@@ -698,7 +750,14 @@ struct XdlopsGemm
"base base_type must be double, float, half, bfloat16, and int8_t!"
);
static_for
<
0
,
KPack
/
mfma_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
if
constexpr
(
!
TransposeC
)
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
p_c_thread
);
}
else
{
mfma_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_b_wave
[
k
],
p_a_wave
[
k
],
p_c_thread
);
}
});
}
...
...
include/ck/utility/static_buffer.hpp
@@ -61,6 +61,7 @@ struct StaticBufferTupleOfVector
    static constexpr auto s_per_v   = Number<ScalarPerVector>{};
    static constexpr auto num_of_v_ = Number<NumOfVector>{};
    static constexpr auto s_per_buf = s_per_v * num_of_v_;

    __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
@@ -70,6 +71,7 @@ struct StaticBufferTupleOfVector
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };
// Get S
// i is offset of S
    template <index_t I>
include/ck/utility/tuple_helper.hpp
@@ -78,4 +78,10 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
        f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}

template <index_t... Is>
__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
{
    return Tuple<Number<Is>...>(Number<Is>{}...);
}
} // namespace ck
library/include/ck/library/utility/check_err.hpp
@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            if(err_count < 128)
            {
                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;