Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
061009a3
Commit
061009a3
authored
Aug 15, 2023
by
aska-0096
Browse files
Compile pass
parent
d1894bdb
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1418 additions
and
208 deletions
+1418
-208
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
...ensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
...block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
+222
-0
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+4
-3
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+49
-204
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+5
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+103
-0
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
.../thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
+1034
-0
No files found.
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
View file @
061009a3
...
@@ -315,7 +315,7 @@ struct Blockwise_fpAintB_GemmWMMA
...
@@ -315,7 +315,7 @@ struct Blockwise_fpAintB_GemmWMMA
fast_numeric_converter
;
fast_numeric_converter
;
// basic intrinsic to determine loopover direction
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
if
constexpr
(
0
)
{
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
0 → 100644
View file @
061009a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace
ck
{
/**
* @brief Blockwise data transfer with dequantization
*
* RunRead would load low-precision data and scale data.
* RunWrite would process dequantization process.
* Assume Scale is identical along K-dimension
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template
<
typename
ThreadGroup
,
typename
SrcElementwiseOperation
,
typename
ScaleElementwiseOperation
,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
BlockSliceLengths
,
typename
BlockScaleSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
ScaleData
,
typename
DstData
,
typename
SrcDesc
,
typename
ScaleDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
ScaleScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
ScaleScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
,
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
static
constexpr
auto
scale_thread_slice_lengths
=
BlockScaleSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadGroupTensorSliceTransfer_v4r1_dequant
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
SrcElementwiseOperation
&
src_element_op
,
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_block_slice_origin
,
const
ScaleElementwiseOperation
&
scale_element_op
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
,
const
DstElementwiseOperation
&
dst_element_op
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
src_element_op
,
scale_desc
,
make_zero_multi_index
<
nDim
>
(),
scale_element_op
,
dst_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_element_op
)
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
ScaleDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{}
&&
is_same
<
BlockScaleSliceLengths
,
decltype
(
scale_thread_slice_lengths
*
ThreadClusterLengths
{})
>
{}
,
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetScaleSliceOrigin
(
scale_desc
,
scale_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
thread_scratch_id
);
}
}
// With the assumption, scale scratch is always one
template
<
typename
ScaleBuffer
>
__device__
void
RunScaleRead
(
const
ScaleDesc
&
scale_desc
,
const
ScaleBuffer
&
scale_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunScaleRead
(
scale_desc
,
scale_buf
);
}
}
template
<
typename
DstBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_desc
,
dst_buf
,
thread_scratch_id
);
}
}
// We don't prefer use this API directly
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
// With the assumption, scale buffer don't need move slice window method
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v3r1_dequant
<
decltype
(
thread_slice_lengths
),
decltype
(
scale_thread_slice_lengths
),
SrcElementwiseOperation
,
ScaleElementwiseOperation
,
DstElementwiseOperation
,
DstInMemOp
,
SrcData
,
ScaleData
,
DstData
,
SrcDesc
,
ScaleDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
ScaleScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
ScaleScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
,
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @
061009a3
...
@@ -66,7 +66,7 @@ template <typename ALayout,
...
@@ -66,7 +66,7 @@ template <typename ALayout,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
dequant_v1
>
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
weight_only
>
struct
DeviceFpAintBGemm_Wmma_CShuffle
:
public
DeviceGemm_dequantB
<
ALayout
,
struct
DeviceFpAintBGemm_Wmma_CShuffle
:
public
DeviceGemm_dequantB
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
...
@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
// LDS bypass feature not
checked
.
// LDS bypass feature not
implemented for dequantization pipeline
.
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
...
@@ -677,7 +677,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -677,7 +677,8 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
dequant_v1
,
"dequant_v1"
}};
{
PipelineVersion
::
dequant_v1
,
"dequant_v1"
},
{
PipelineVersion
::
weight_only
,
"weight_only"
}};
// clang-format off
// clang-format off
str
<<
"DeviceFpAintBGemm_Wmma_CShuffle"
str
<<
"DeviceFpAintBGemm_Wmma_CShuffle"
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
061009a3
...
@@ -9,8 +9,9 @@
...
@@ -9,8 +9,9 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_
fpAintB_
gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
@@ -82,6 +83,7 @@ __global__ void
...
@@ -82,6 +83,7 @@ __global__ void
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx1100__))
}
}
// Assume B is Col-Major
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
...
@@ -129,7 +131,7 @@ template <index_t BlockSize,
...
@@ -129,7 +131,7 @@ template <index_t BlockSize,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
dequant_v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
weight_only
>
struct
GridwiseFpAintBGemm_Wmma
struct
GridwiseFpAintBGemm_Wmma
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -252,38 +254,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -252,38 +254,6 @@ struct GridwiseFpAintBGemm_Wmma
return
b_block_desc
;
return
b_block_desc
;
}
}
__host__
__device__
static
constexpr
auto
MakeScaleBlockDescriptor
()
{
// Scale [1, N], all K related dimension reduce to 1
constexpr
auto
scale_block_desc
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
// K0->N->K1 Per Block
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
I1
),
make_tuple
(
I0
,
I1
,
I0
));
}
else
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
),
make_tuple
(
I0
,
I1
,
I0
,
I0
,
I0
,
I0
,
I0
));
}
}();
return
scale_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
{
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
...
@@ -424,47 +394,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -424,47 +394,6 @@ struct GridwiseFpAintBGemm_Wmma
return
b_wave_desc
;
return
b_wave_desc
;
}
}
template
<
typename
ScaleBlockDesc_
>
__host__
__device__
static
constexpr
auto
MakeScaleWaveDescriptor
(
const
ScaleBlockDesc_
&
)
{
constexpr
auto
scale_wave_desc
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
ScaleBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
ScaleBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_KRow
=
I1
;
return
transform_tensor_descriptor
(
ScaleBlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
5
>
{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr
auto
KWmma
=
ScaleBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
K0PerWmma
=
ScaleBlockDesc_
{}.
GetLength
(
I3
);
constexpr
auto
B_KRow
=
ScaleBlockDesc_
{}.
GetLength
(
I4
);
constexpr
auto
B_K1
=
ScaleBlockDesc_
{}.
GetLength
(
I6
);
// Workaround, Freeze transform
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmma
*
K0PerWmma
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
I0
,
I1
,
I0
,
I0
,
I0
,
I0
));
}
}();
return
scale_wave_desc
;
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
// *Caution Here repeat is shuffle repeat
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
...
@@ -613,8 +542,11 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -613,8 +542,11 @@ struct GridwiseFpAintBGemm_Wmma
struct
SharedMemTrait
struct
SharedMemTrait
{
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and Dequantized B: be careful of DataType
// scale would not put into LDS.
using
LDS_ADataType
=
ADataType
;
using
LDS_BDataType
=
ADataType
;
using
LDS_CDataType
=
CShuffleDataType
;
static
constexpr
auto
max_lds_align
=
K1
;
static
constexpr
auto
max_lds_align
=
K1
;
static
constexpr
auto
a_block_space_size_aligned
=
static
constexpr
auto
a_block_space_size_aligned
=
...
@@ -625,18 +557,13 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -625,18 +557,13 @@ struct GridwiseFpAintBGemm_Wmma
BEnableLds
?
math
::
integer_least_multiple
(
MakeBBlockDescriptor
().
GetElementSpaceSize
(),
BEnableLds
?
math
::
integer_least_multiple
(
MakeBBlockDescriptor
().
GetElementSpaceSize
(),
max_lds_align
)
max_lds_align
)
:
0
;
:
0
;
static
constexpr
auto
scale_block_space_size_aligned
=
BEnableLds
?
math
::
integer_least_multiple
(
MakeScaleBlockDescriptor
().
GetElementSpaceSize
(),
max_lds_align
)
:
0
;
static
constexpr
auto
a_block_space_offset
=
0
;
static
constexpr
auto
a_block_space_offset
=
0
;
// B would be dequantize to ADataType before enter LDS
// b_lds_offset = LDS size allocated for a in byte / LDS_BDataType
static
constexpr
auto
b_block_space_offset
=
static
constexpr
auto
b_block_space_offset
=
(
a_block_space_offset
+
a_block_space_size_aligned
)
*
sizeof
(
ADataType
)
/
(
a_block_space_offset
+
a_block_space_size_aligned
)
*
sizeof
(
LDS_ADataType
)
/
sizeof
(
BDataType
);
sizeof
(
LDS_BDataType
);
static
constexpr
auto
scale_block_space_offset
=
(
b_block_space_offset
+
b_block_space_size_aligned
)
*
sizeof
(
BDataType
)
/
sizeof
(
ScaleDataType
);
// LDS allocation for C shuffle in LDS
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_space_size
=
static
constexpr
auto
c_shuffle_block_space_size
=
...
@@ -646,10 +573,9 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -646,10 +573,9 @@ struct GridwiseFpAintBGemm_Wmma
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
lds_size
=
static
constexpr
auto
lds_size
=
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
CShuffleDataType
),
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
LDS_CDataType
),
a_block_space_size_aligned
*
sizeof
(
ADataType
)
+
a_block_space_size_aligned
*
sizeof
(
LDS_ADataType
)
+
b_block_space_size_aligned
*
sizeof
(
BDataType
)
+
b_block_space_size_aligned
*
sizeof
(
LDS_BDataType
));
scale_block_space_size_aligned
*
sizeof
(
ScaleDataType
));
};
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
...
@@ -707,7 +633,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -707,7 +633,6 @@ struct GridwiseFpAintBGemm_Wmma
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
b_block_desc
=
MakeBBlockDescriptor
();
constexpr
auto
b_block_desc
=
MakeBBlockDescriptor
();
constexpr
auto
scale_block_desc
=
MakeScaleBlockDescriptor
();
auto
a_block_trait
=
[
&
](){
auto
a_block_trait
=
[
&
](){
// A matrix blockwise copy
// A matrix blockwise copy
...
@@ -795,35 +720,44 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -795,35 +720,44 @@ struct GridwiseFpAintBGemm_Wmma
{
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
B
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
static_cast
<
A
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
SharedMemTrait
::
b_block_space_size_aligned
);
SharedMemTrait
::
b_block_space_size_aligned
);
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v4r1_dequant
<
ThisThreadBlock
,
BElementwiseOperation
,
/* typename SrcElementwiseOperation, */
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
/* typename ScaleElementwiseOperation, */
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
/* typename DstElementwiseOperation, */
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
/* InMemoryDataOperationEnum DstInMemOp, */
InMemoryDataOperationEnum
::
Set
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
/* typename BlockSliceLengths, */
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterArrangeOrder
,
/* typename BlockScaleSliceLengths, */
Sequence
<
K0PerBlock
,
NPerBlock
,
I1
>
,
BDataType
,
/* typename ThreadClusterLengths, */
BBlockTransferThreadClusterLengths_K0_N_K1
,
BDataType
,
/* typename ThreadClusterArrangeOrder, */
BBlockTransferThreadClusterArrangeOrder
,
decltype
(
b_grid_desc
),
/* typename SrcData, */
BDataType
,
decltype
(
b_block_desc
),
/* typename ScaleData, */
ScaleDataType
,
BBlockTransferSrcAccessOrder
,
/* typename DstData, */
ADataType
,
Sequence
<
0
,
1
,
2
>
,
/* typename SrcDesc, */
decltype
(
b_grid_desc
),
BBlockTransferSrcVectorDim
,
/* typename ScaleDesc, */
decltype
(
scale_grid_desc
),
2
,
/* typename DstDesc, */
decltype
(
b_block_desc
),
BBlockTransferSrcScalarPerVector
,
/* typename SrcDimAccessOrder, */
BBlockTransferSrcAccessOrder
,
BBlockTransferDstScalarPerVector_K1
,
/* typename DstDimAccessOrder, */
Sequence
<
0
,
1
,
2
>
,
1
,
/* index_t SrcVectorDim, */
BBlockTransferSrcVectorDim
,
1
,
/* index_t DstVectorDim, */
2
,
BThreadTransferSrcResetCoordinateAfterRun
,
/* index_t SrcScalarPerVector, */
BBlockTransferSrcScalarPerVector
,
true
,
/* index_t ScaleScalarPerVector, */
1
,
/* index_t DstScalarPerVector, */
BBlockTransferDstScalarPerVector_K1
,
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t ScaleScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
BThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
NumGemmKPrefetchStage
>
(
NumGemmKPrefetchStage
>
(
b_grid_desc
,
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_element_op
,
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
b_block_desc
,
b_block_desc
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -870,108 +804,22 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -870,108 +804,22 @@ struct GridwiseFpAintBGemm_Wmma
}
}
};
};
auto
scale_block_trait
=
[
&
](){
if
constexpr
(
BEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
scale_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ScaleDataType
*>
(
p_shared
)
+
SharedMemTrait
::
scale_block_space_offset
,
SharedMemTrait
::
scale_block_space_size_aligned
);
auto
scale_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
// Reduce slice length K1 to 1
Sequence
<
K0PerBlock
,
NPerBlock
,
I1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_grid_desc
),
decltype
(
scale_block_desc
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
1
,
1
,
1
,
// no effect
1
,
// no effect
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
scale_block_desc
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
return
make_tuple
(
scale_block_buf
,
scale_blockwise_copy
);
}
else
{
// Thread-wise copy
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
auto
scale_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ScaleDataType
>
(
scale_block_desc
.
GetElementSpaceSize
());
auto
scale_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_grid_desc
),
decltype
(
scale_block_desc
),
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
scale_block_buf
,
scale_blockwise_copy
);
}
};
auto
a_block_buf
=
a_block_trait
()[
I0
];
auto
a_block_buf
=
a_block_trait
()[
I0
];
auto
a_blockwise_copy
=
a_block_trait
()[
I1
];
auto
a_blockwise_copy
=
a_block_trait
()[
I1
];
auto
b_block_buf
=
b_block_trait
()[
I0
];
auto
b_block_buf
=
b_block_trait
()[
I0
];
auto
b_blockwise_copy
=
b_block_trait
()[
I1
];
auto
b_blockwise_copy
=
b_block_trait
()[
I1
];
auto
scale_block_buf
=
scale_block_trait
()[
I0
];
auto
scale_blockwise_copy
=
scale_block_trait
()[
I1
];
/*******************************************************************************/
/*******************************************************************************/
// GEMM
// GEMM
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
auto
blockwise_gemm
=
auto
blockwise_gemm
=
Blockwise
_fpAintB_
GemmWMMA
<
BlockSize
,
BlockwiseGemmWMMA
<
BlockSize
,
ADataType
,
ADataType
,
BDataType
,
ADataType
,
//Dequantized
ScaleDataType
,
AccDataType
,
AccDataType
,
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeBWaveDescriptor
(
b_block_desc
)),
decltype
(
MakeBWaveDescriptor
(
b_block_desc
)),
decltype
(
MakeScaleWaveDescriptor
(
scale_block_desc
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -1006,10 +854,7 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -1006,10 +854,7 @@ struct GridwiseFpAintBGemm_Wmma
b_block_buf
,
b_block_buf
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
scale_grid_desc
,
scale_grid_desc
,
scale_block_desc
,
scale_blockwise_copy
,
scale_grid_buf
,
scale_grid_buf
,
scale_block_buf
,
blockwise_gemm
,
blockwise_gemm
,
c_thread_buf
,
c_thread_buf
,
KBlockMainLoop
);
KBlockMainLoop
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
061009a3
...
@@ -13,6 +13,7 @@ enum struct PipelineVersion
...
@@ -13,6 +13,7 @@ enum struct PipelineVersion
v1
,
v1
,
v2
,
v2
,
dequant_v1
,
dequant_v1
,
weight_only
,
};
};
template
<
PipelineVersion
PipelineVer
,
template
<
PipelineVersion
PipelineVer
,
...
@@ -41,6 +42,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
...
@@ -41,6 +42,10 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
{
return
GridwiseGemmPipeline_v1_dequant
<
NumPrefetch
,
AEnableLds
,
BEnableLds
>
{};
return
GridwiseGemmPipeline_v1_dequant
<
NumPrefetch
,
AEnableLds
,
BEnableLds
>
{};
}
}
else
if
constexpr
(
PipelineVer
==
PipelineVersion
::
weight_only
)
{
return
GridwiseGemmPipeline_v1_WeightOnly
<
NumPrefetch
,
AEnableLds
,
BEnableLds
>
{};
}
else
else
{
{
std
::
cerr
<<
"GridwiseGemmPipeline configuration is not available"
<<
std
::
endl
;
std
::
cerr
<<
"GridwiseGemmPipeline configuration is not available"
<<
std
::
endl
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
061009a3
...
@@ -769,6 +769,109 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, false>
...
@@ -769,6 +769,109 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, false>
}
}
};
};
template
<
index_t
NumPrefetch
,
bool
AEnableLds
,
bool
BEnableLds
>
struct
GridwiseGemmPipeline_v1_WeightOnly
;
template
<
>
struct
GridwiseGemmPipeline_v1_WeightOnly
<
1
,
true
,
true
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/* num_loop */
)
{
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
>
1
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
ScaleGridDesc
,
typename
ScaleGridBuffer
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
ScaleGridDesc
&
scale_grid_desc
,
const
ScaleGridBuffer
&
scale_grid_buf
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
// Global Prefetch Stage 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// Scale read once
b_blockwise_copy
.
RunScaleRead
(
scale_grid_desc
,
scale_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// Dequantization fused in blockwise_copy
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
block_sync_lds
();
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
++
i
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
template
<
index_t
NumPrefetch
>
template
<
index_t
NumPrefetch
>
struct
GridwiseGemmPipelineInterwave_v1
;
struct
GridwiseGemmPipelineInterwave_v1
;
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
0 → 100644
View file @
061009a3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace
ck
{
namespace
detail
{
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template
<
index_t
SrcVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstVectorDim
,
index_t
DstScalarPerVector
>
struct
lambda_scalar_per_access_for_src_and_dst_idle
{
__host__
__device__
constexpr
auto
operator
()(
index_t
i
)
const
{
if
(
i
==
SrcVectorDim
&&
i
==
DstVectorDim
)
{
return
math
::
lcm
(
SrcScalarPerVector
,
DstScalarPerVector
);
}
else
if
(
i
==
SrcVectorDim
)
{
return
SrcScalarPerVector
;
}
else
if
(
i
==
DstVectorDim
)
{
return
DstScalarPerVector
;
}
else
{
return
1
;
}
}
};
}
// namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
// 5. Dequantization happened between read and write.
template
<
typename
SliceLengths
,
typename
ScaleSliceLengths
,
typename
SrcElementwiseOperation
,
typename
ScaleElementwiseOperation
,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
SrcData
,
typename
ScaleData
,
typename
DstData
,
typename
SrcDesc
,
typename
ScaleDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
ScaleScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
ScaleScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
SrcResetCoordinateAfterRun
,
// control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool
DstResetCoordinateAfterRun
,
// control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t
NumThreadScratch
=
1
>
struct
ThreadwiseTensorSliceTransfer_v3r1_dequant
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
ScaleCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
static
constexpr
auto
I0
=
Number
<
0
>
{};
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1_dequant
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
SrcElementwiseOperation
&
src_element_op
,
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_slice_origin
,
const
ScaleElementwiseOperation
&
scale_element_op
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
,
const
DstElementwiseOperation
&
dst_element_op
)
:
src_coord_
(
make_tensor_coordinate
(
src_desc
,
src_slice_origin
)),
scale_coord_
(
make_tensor_coordinate
(
scale_desc
,
scale_slice_origin
)),
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
src_element_op_
(
src_element_op
),
scale_element_op_
(
scale_element_op
),
dst_element_op_
(
dst_element_op
)
{
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
__device__
void
SetScaleSliceOrigin
(
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_slice_origin_idx
)
{
scale_coord_
=
make_tensor_coordinate
(
scale_desc
,
scale_slice_origin_idx
);
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
)
{
dst_coord_
=
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
static_assert
(
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
// make forward steps
const
auto
src_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
src_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
src_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
src_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
src_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
src_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_src_access_lengths
)
>
{}([
&
](
auto
ordered_src_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate src data index
constexpr
auto
src_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_src_access_idx
[
i
]
:
ordered_src_access_lengths
[
i
]
-
1
-
ordered_src_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
src_dim_access_order
)
*
src_scalar_per_access
;
}();
constexpr
auto
src_data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
src_data_idx
[
i
]
>
{};
},
Number
<
src_data_idx
.
Size
()
>
{});
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
// copy data from src_buf into src_vector_container
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
src_vector_t
>(
src_data_idx_seq
,
src_vector_container
.
template
AsType
<
src_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_src_access_idx
[
i
]
<
ordered_src_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_src_access_idx
[
j
]
==
ordered_src_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move src coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_forward_steps
[
src_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_backward_steps
[
src_dim_access_order
[
i
]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if
constexpr
(
SrcResetCoordinateAfterRun
)
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_desc
,
GetSrcCoordinateResetStep
());
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_reset_step
);
}
}
template
<
typename
ScaleBuffer
>
__device__
void
RunScaleRead
(
const
ScaleDesc
&
scale_desc
,
const
ScaleBuffer
&
scale_buf
)
{
static_assert
(
ScaleBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
ScaleBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
ScaleBuffer
::
type
>
,
remove_cvref_t
<
ScaleData
>>::
value
,
"wrong! ScaleBuffer and ScaleData data type are inconsistent"
);
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
scale_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
ScaleScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
scale_access_lengths
=
SliceLengths
{}
/
scale_scalar_per_access
;
constexpr
auto
scale_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_scale_access_lengths
=
container_reorder_given_new2old
(
scale_access_lengths
,
scale_dim_access_order
);
// make forward steps
const
auto
scale_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
scale_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
scale_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
scale_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
scale_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
scale_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_scale_access_lengths
)
>
{}([
&
](
auto
ordered_scale_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_scale_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_scale_access_lengths
[
j
]
+
ordered_scale_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate scale data index
constexpr
auto
scale_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_scale_access_idx
[
i
]
:
ordered_scale_access_lengths
[
i
]
-
1
-
ordered_scale_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
scale_dim_access_order
)
*
scale_scalar_per_access
;
}();
constexpr
auto
scale_data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
scale_data_idx
[
i
]
>
{};
},
Number
<
scale_data_idx
.
Size
()
>
{});
const
bool
is_scale_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
scale_desc
,
scale_coord_
);
using
scale_vector_type
=
vector_type_maker_t
<
ScaleData
,
ScaleScalarPerVector
>
;
using
scale_vector_t
=
typename
scale_vector_type
::
type
;
// copy data from scale_buf into scale_vector_container
auto
scale_vector_container
=
scale_vector_type
{
scale_buf
.
template
Get
<
scale_vector_t
>(
scale_coord_
.
GetOffset
(),
is_scale_valid
)};
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_
.
template
SetAsType
<
scale_vector_t
>(
scale_data_idx_seq
,
scale_vector_container
.
template
AsType
<
scale_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_scale_access_idx
[
i
]
<
ordered_scale_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_scale_access_idx
[
j
]
==
ordered_scale_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move scale coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
scale_desc
,
scale_coord_
,
scale_forward_steps
[
scale_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
scale_desc
,
scale_coord_
,
scale_backward_steps
[
scale_dim_access_order
[
i
]]);
}
}
});
});
// don't need to move scale coordinate back to slice origin
/*
if constexpr(SrcResetCoordinateAfterRun)
{
const auto scale_reset_step =
make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
}
*/
}
template
<
index_t
ThreadScratchId
>
__device__
void
TransferDataFromSrcThreadScratchToDstThreadScratch
(
Number
<
ThreadScratchId
>
thread_scratch_id
)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
((
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
(
is_same
<
int8_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr
index_t
num_src_vector
=
Number
<
DstScalarPerVector
>
{};
constexpr
index_t
num_dst_vector
=
Number
<
SrcScalarPerVector
>
{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert
(
SrcVectorDim
!=
DstVectorDim
,
"wrong"
);
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_scalar_step_in_vector
=
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
DstVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access_for_src_and_dst_idle
<
SrcVectorDim
,
SrcScalarPerVector
,
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
scalar_per_access
;
static_ford
<
decltype
(
access_lengths
)
>
{}([
&
](
auto
access_idx
)
{
constexpr
auto
data_idx
=
access_idx
*
scalar_per_access
;
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
using
src_vector_t
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const
auto
src_vector_refs
=
generate_tie
(
[
&
](
auto
i
)
->
const
src_vector_t
&
{
// i increment corresponds to movement in DstVectorDim
return
src_thread_scratch_tuple_
[
thread_scratch_id
].
GetVectorTypeReference
(
data_idx_seq
+
i
*
dst_scalar_step_in_vector
);
},
Number
<
num_src_vector
>
{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto
dst_vector_refs
=
generate_tie
(
[
&
](
auto
i
)
->
dst_vector_t
&
{
// i increment corresponds to movement in SrcVectorDim
return
dst_thread_scratch_
.
GetVectorTypeReference
(
data_idx_seq
+
i
*
src_scalar_step_in_vector
);
},
Number
<
num_dst_vector
>
{});
// do data transpose
transpose_vectors
<
SrcData
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
// do fast numeric convert
src_converted_thread_scratch_
.
template
SetAsType
<
SrcThreadConvertedScratch
::
V
>(
access_idx
,
fast_numeric_converter
(
src_thread_scratch_tuple_
[
thread_scratch_id
].
template
GetAsType
<
SrcThreadScratch
::
V
>(
access_idx
)));
});
}
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
// Scale is dynamic, could not implement through element_op.
DstData
dst_v
;
constexpr
auto
scale_idx
=
Sequence
<
I0
,
idx
.
At
(
1
),
I0
>
{};
src_element_op_
(
dst_v
,
src_converted_thread_scratch_
[
idx
]
*
scale_thread_scratch_
[
scale_idx
]);
dst_thread_scratch_
(
idx
)
=
dst_v
;
});
#endif
}
template
<
typename
DstBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch
(
thread_scratch_id
);
static_assert
(
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
// src scalar per access on each dim
// TODO: don't use this
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
ordered_dst_access_lengths
=
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
// make forward steps
const
auto
dst_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
dst_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_dst_access_lengths
)
>
{}([
&
](
auto
ordered_dst_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate dst data index
constexpr
auto
dst_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_dst_access_idx
[
i
]
:
ordered_dst_access_lengths
[
i
]
-
1
-
ordered_dst_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
dst_dim_access_order
)
*
dst_scalar_per_access
;
}();
constexpr
auto
dst_data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
dst_data_idx
[
i
]
>
{};
},
Number
<
dst_data_idx
.
Size
()
>
{});
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc
,
dst_coord_
);
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
// copy data from dst_thread_scratch_ into dst_vector_container
auto
dst_vector_container
=
dst_vector_type
{
dst_thread_scratch_
.
template
GetAsType
<
dst_vector_t
>(
dst_data_idx_seq
)};
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
DstData
dst_v
;
// apply DstElementwiseOperation
dst_element_op_
(
dst_v
,
dst_vector_container
.
template
AsType
<
DstData
>()[
i
]);
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
dst_v
;
});
// copy data from dst_vector_container to dst_buf
dst_buf
.
template
Set
<
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector_container
.
template
AsType
<
dst_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_dst_access_idx
[
i
]
<
ordered_dst_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_dst_access_idx
[
j
]
==
ordered_dst_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move dst coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_forward_steps
[
dst_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_backward_steps
[
dst_dim_access_order
[
i
]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if
constexpr
(
DstResetCoordinateAfterRun
)
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_desc
,
GetDstCoordinateResetStep
());
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_reset_step
);
}
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_lengths
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_lengths
[
j
]
-
1
;
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr
auto
src_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_src_access_lengths
[
i
]
-
1
:
0
;
});
return
container_reorder_given_old2new
(
ordered_idx
,
src_dim_access_order
)
*
src_scalar_per_access
;
}();
//
constexpr
auto
reset_src_data_step
=
[
&
]()
{
Index
reset_src_data_step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
reset_src_data_step_
(
i
)
=
-
src_data_idx
[
i
];
});
return
reset_src_data_step_
;
}();
return
reset_src_data_step
;
}
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
ordered_dst_access_lengths
=
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_lengths
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_lengths
[
j
]
-
1
;
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr
auto
dst_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_dst_access_lengths
[
i
]
-
1
:
0
;
});
return
container_reorder_given_old2new
(
ordered_idx
,
dst_dim_access_order
)
*
dst_scalar_per_access
;
}();
//
constexpr
auto
reset_dst_data_step
=
[
&
]()
{
Index
reset_dst_data_step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
reset_dst_data_step_
(
i
)
=
-
dst_data_idx
[
i
];
});
return
reset_dst_data_step_
;
}();
return
reset_dst_data_step
;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRun
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRun
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
adjusted_step
);
}
__device__
static
constexpr
auto
GetSrcThreadScratchDescriptor
()
{
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
src_access_lengths
),
Number
<
SrcScalarPerVector
>
{});
// 1st stage of transforms
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
src_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
src_access_lengths_and_vector_length
[
i
],
src_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
src_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
__device__
static
constexpr
auto
GetScaleThreadScratchDescriptor
()
{
constexpr
auto
scale_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
ScaleScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
scale_access_lengths
=
SliceLengths
{}
/
scale_scalar_per_access
;
constexpr
auto
scale_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
scale_access_lengths
),
Number
<
ScaleScalarPerVector
>
{});
// 1st stage of transforms
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
scale_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
scale_access_lengths_and_vector_length
[
i
],
scale_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
scale_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
__device__
static
constexpr
auto
GetDstThreadScratchDescriptor
()
{
// 1st stage of transforms
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
dst_access_lengths
),
Number
<
DstScalarPerVector
>
{});
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
dst_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
DstVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
dst_access_lengths_and_vector_length
[
i
],
dst_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
dst_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
DstVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
private:
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
scale_thread_scratch_desc_
=
decltype
(
GetScaleThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
SrcData
,
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
// Registers, contain fast converted data
using
SrcThreadConvertedScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
// Registers, contain scale data
using
ScaleThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
ScaleData
,
ScaleScalarPerVector
,
decltype
(
scale_thread_scratch_desc_
),
true
>
;
// Registers, contain dequantized data
using
DstThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
DstScalarPerVector
,
decltype
(
dst_thread_scratch_desc_
),
true
>
;
using
FastTypeConverter
=
tensor_operation
::
element_wise
::
FastNumericArrayConverter
<
SrcData
,
DstData
,
SrcScalarPerVector
>
;
StaticallyIndexedArray
<
SrcThreadScratch
,
NumThreadScratch
>
src_thread_scratch_tuple_
;
SrcThreadConvertedScratch
src_converted_thread_scratch_
;
ScaleThreadScratch
scale_thread_scratch_
;
DstThreadScratch
dst_thread_scratch_
;
FastTypeConverter
fast_numeric_converter
;
SrcCoord
src_coord_
;
ScaleCoord
scale_coord_
;
DstCoord
dst_coord_
;
const
SrcElementwiseOperation
src_element_op_
;
const
ScaleElementwiseOperation
scale_element_op_
;
const
DstElementwiseOperation
dst_element_op_
;
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment