Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
db843529
"test/split_k/main.cpp" did not exist on "982e59b366ab14509eb04e2061f82ebefe5f87a5"
Commit
db843529
authored
Dec 30, 2024
by
coderfeli
Browse files
fix warnings and revert cmake and fix clang format
parent
5765ba51
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
188 additions
and
207 deletions
+188
-207
CMakeLists.txt
CMakeLists.txt
+4
-0
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
...gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
+42
-47
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
...l/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
+66
-63
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
...id/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
+71
-94
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+3
-2
No files found.
CMakeLists.txt
View file @
db843529
...
...
@@ -516,6 +516,10 @@ include_directories(BEFORE
)
SET
(
BUILD_DEV ON CACHE BOOL
"BUILD_DEV"
)
if
(
BUILD_DEV
)
add_compile_options
(
-Werror
)
add_compile_options
(
-Weverything
)
endif
()
message
(
"CMAKE_CXX_FLAGS:
${
CMAKE_CXX_FLAGS
}
"
)
if
(
"
${
CMAKE_CXX_COMPILER_ID
}
"
MATCHES
"Clang"
)
...
...
cmake/EnableCompilerWarnings.cmake
View file @
db843529
...
...
@@ -66,6 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle.hpp
View file @
db843529
...
...
@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now...
>
struct
BlockwiseGemmXdlops_pipeline_bpreshuffle
<
BlockGemmPipelineScheduler
::
Intrawave
,
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
BlockSize
,
ADataType
,
BDataType
,
ComputeDataType
,
AccDataType
,
ATileDesc
,
BTileDesc
,
AMmaTileDesc
,
BMmaTileDesc
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
:
BlockwiseGemmXdlops_pipeline_base
<
BlockSize
,
ADataType
,
BDataType
,
...
...
@@ -137,7 +137,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
using
Base
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
;
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
...
...
@@ -271,10 +270,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
AGridDesc
&
a_grid_desc
,
...
...
@@ -285,10 +282,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
ABlockBuffer
&
a_block_buf1
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
const
...
...
@@ -296,8 +291,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
__builtin_amdgcn_sched_barrier
(
0
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
...
...
@@ -348,14 +341,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
...
...
@@ -399,8 +393,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
...
@@ -449,25 +444,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
0
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -477,11 +471,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf1
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k0
*
AMmaKStride
>
{}),
a_block_buf1
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
});
});
...
...
@@ -491,8 +485,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
ComputeDataType
,
KPack
>
a_thread_vec
;
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
vector_type
<
ComputeDataType
,
KPack
>
b_thread_vec
=
b_blockwise_copy
.
template
GetSrcThreadScratchIdx
<
Sequence
<
0
,
k0
,
0
>,
Number
<
1
>
{}
>
();
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
ik
)
{
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
View file @
db843529
...
...
@@ -112,7 +112,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
template
<
typename
SeqIdx
,
index_t
ThreadScratchId
=
0
>
__device__
constexpr
auto
GetSrcThreadScratchIdx
()
{
return
threadwise_transfer_
.
template
GetSrcThreadScratchIdx
<
SeqIdx
,
ThreadScratchId
>();
return
threadwise_transfer_
.
template
GetSrcThreadScratchIdx
<
SeqIdx
,
ThreadScratchId
>();
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp
View file @
db843529
...
...
@@ -67,55 +67,57 @@ template <typename ALayout,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ComputeTypeA
,
typename
LDSTypeB
=
ComputeTypeB
>
struct
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
:
public
DeviceGemmMultiD_Xdl_CShuffle_V3
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
GemmAccDataType
,
CShuffleDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVectors
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>
typename
LDSTypeB
=
ComputeTypeB
>
struct
DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
:
public
DeviceGemmMultiD_Xdl_CShuffle_V3
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
GemmAccDataType
,
CShuffleDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CDEShuffleBlockTransferScalarPerVectors
,
BlkGemmPipeSched
,
BlkGemmPipelineVer
,
ComputeTypeA
,
ComputeTypeB
,
LDSTypeA
,
LDSTypeB
>
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -172,7 +174,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
LDSTypeA
,
LDSTypeB
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
...
...
@@ -267,7 +268,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
constexpr
index_t
minimum_occupancy
=
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right now");
// static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
// has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
// now");
if
(
has_main_k_block_loop
)
{
// Tail number always full
...
...
@@ -284,11 +287,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
...
...
@@ -298,7 +301,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
}
}
else
{
{
if
(
arg
.
KBatch
>
1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
...
...
@@ -310,11 +313,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
}
}
...
...
@@ -437,4 +440,4 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
\ No newline at end of file
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp
View file @
db843529
...
...
@@ -125,23 +125,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static
constexpr
auto
CShuffleBlockTransferScalarPerVector_NPerBlock
=
CDEShuffleBlockTransferScalarPerVectors
{}[
I0
];
// K1 should be Number<...>
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BlockSizeNumber
=
Number
<
BlockSize
>
{};
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
using
mfma_selector
=
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
,
ComputeTypeB
>
;
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
mfma_selector
::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KLane
=
mfma_selector
::
GetKPerXdlops
()
/
mfma_selector
::
GetK1PerXdlops
();
static
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
mfma_selector
::
selected_mfma
.
k_per_blk
);
static
constexpr
index_t
KLane
=
mfma_selector
::
GetKPerXdlops
()
/
mfma_selector
::
GetK1PerXdlops
();
static
constexpr
index_t
KRepeat
=
KPerBlock
/
KLane
/
KPack
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static
constexpr
index_t
NLane
=
NPerXdl
;
static
constexpr
index_t
NWave
=
NPerBlock
/
NPerXdl
/
NXdlPerWave
;
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static_assert
(
NXdlPerWave
==
1
,
"only 1 validated now, tbd next week"
);
static
constexpr
auto
MakeDsGridPointer
()
{
return
generate_tuple
(
...
...
@@ -320,12 +322,10 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
auto
MakeBGridDescriptor_Preshuffled
(
index_t
N0
,
index_t
K0
)
{
constexpr
index_t
NkSwizzle
=
BlockSize
*
KPack
;
constexpr
index_t
NkSwizzle
=
BlockSize
*
KPack
;
constexpr
index_t
NkSwizzleNumber
=
Number
<
NkSwizzle
>
{};
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
make_tuple
(
K0
*
NkSwizzle
,
NkSwizzleNumber
,
I1
)
);
return
make_naive_tensor_descriptor
(
make_tuple
(
N0
,
K0
,
NkSwizzleNumber
),
make_tuple
(
K0
*
NkSwizzle
,
NkSwizzleNumber
,
I1
));
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
...
...
@@ -423,9 +423,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__host__
__device__
static
constexpr
auto
MakeBMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
return
MakeGemmMmaTileDescriptor
<
NXdlPerWave
,
NWave
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
template
<
typename
ELayout
>
...
...
@@ -943,7 +941,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
...
...
@@ -955,44 +952,40 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
using
BlockwiseGemmPipe
=
remove_cvref_t
<
decltype
(
BlockwiseGemmXdlops_pipeline_bpreshuffle
<
BlkGemmPipeSched
,
BlockSize
,
LDSTypeA
,
LDSTypeB
,
ComputeTypeA
,
AccDataType
,
decltype
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()),
decltype
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()),
decltype
(
MakeAMmaTileDescriptor_M0_M1_M2_K
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
())),
decltype
(
MakeBMmaTileDescriptor_N0_N1_N2_K
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
())),
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
>
{})
>
;
using
BlockwiseGemmPipe
=
remove_cvref_t
<
decltype
(
BlockwiseGemmXdlops_pipeline_bpreshuffle
<
BlkGemmPipeSched
,
BlockSize
,
LDSTypeA
,
LDSTypeB
,
ComputeTypeA
,
AccDataType
,
decltype
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()),
decltype
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()),
decltype
(
MakeAMmaTileDescriptor_M0_M1_M2_K
(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
())),
decltype
(
MakeBMmaTileDescriptor_N0_N1_N2_K
(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
())),
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
>
{})
>
;
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
// constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
// b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
...
...
@@ -1259,8 +1252,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
b_grid_desc_bpreshuffled
=
MakeBGridDescriptor_Preshuffled
(
problem
.
BN0Shuffled
,
problem
.
BK0Shuffled
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
<
CLayout
>
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
...
...
@@ -1294,10 +1287,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
(
NPerBlock
/
NLane
/
NWave
))
;
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
__builtin_amdgcn_readfirstlane
(
block_n_id
*
(
NPerBlock
/
NLane
/
NWave
));
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
...
@@ -1339,51 +1329,42 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
// using BThreadClusterLengths = Sequence<1, 1, BlockSize>;
// using BBlockTransferClusterArrangeOrder = Sequence<0, 1, 2>;
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
KRepeat
,
KPack
*
BlockSize
>
,
Sequence
<
1
,
1
,
BlockSize
>
,
//BThreadClusterLengths,
Sequence
<
0
,
1
,
2
>
,
//BBlockTransferClusterArrangeOrder,
BDataType
,
LDSTypeB
,
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_block_desc_bk0_n_bk1
),
Sequence
<
0
,
1
,
2
>
,
//BBlockTransferSrcAccessOrder,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
2
>
(
b_grid_desc_bpreshuffled
,
make_multi_index
(
n_block_data_idx_on_grid
,
0
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
KRepeat
,
KPack
*
BlockSize
>
,
Sequence
<
1
,
1
,
BlockSize
>
,
// BThreadClusterLengths,
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferClusterArrangeOrder,
BDataType
,
LDSTypeB
,
decltype
(
b_grid_desc_bpreshuffled
),
decltype
(
b_block_desc_bk0_n_bk1
),
Sequence
<
0
,
1
,
2
>
,
// BBlockTransferSrcAccessOrder,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
2
>
(
b_grid_desc_bpreshuffled
,
make_multi_index
(
n_block_data_idx_on_grid
,
0
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
// Cast after lds
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeA
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
a_block_buf1
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeA
*>
(
p_shared1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
LDSTypeB
*>
(
p_shared
)
+
a_block_space_size_aligned
*
sizeof
(
LDSTypeA
)
/
sizeof
(
LDSTypeB
),
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
KRepeat
,
0
);
...
...
@@ -1404,10 +1385,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_block_buf1
,
a_block_slice_copy_step
,
b_grid_desc_bpreshuffled
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
c_thread_buf
,
num_k_block_main_loop
);
...
...
@@ -1419,7 +1398,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
...
...
@@ -1672,7 +1650,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
db843529
...
...
@@ -268,12 +268,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}
template
<
typename
SeqIdx
,
index_t
ThreadScratchId
=
0
>
__device__
constexpr
auto
GetSrcThreadScratchIdx
(
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
__device__
constexpr
auto
GetSrcThreadScratchIdx
(
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
using
vector_t
=
typename
vector_type_maker
<
SrcData
,
SrcScalarPerVector
>::
type
::
type
;
return
src_thread_scratch_tuple_
(
thread_scratch_id
).
template
GetAsType
<
vector_t
>(
SeqIdx
{});
}
template
<
index_t
ThreadScratchId
>
__device__
void
TransferDataFromSrcThreadScratchToDstThreadScratch
(
Number
<
ThreadScratchId
>
thread_scratch_id
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment