Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cc0ffeb7
Unverified
Commit
cc0ffeb7
authored
Aug 16, 2023
by
Haocong WANG
Committed by
GitHub
Aug 16, 2023
Browse files
Merge pull request #851 from ROCmSoftwarePlatform/perf_opt_fpAintB
New implementation fpAintB
parents
3ba0f0d7
bf75259f
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1361 additions
and
967 deletions
+1361
-967
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+6
-6
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
...ensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
+0
-624
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
...block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
+223
-0
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
...or_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
+3
-3
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+49
-204
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+3
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+8
-124
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
.../thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
+1066
-0
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
cc0ffeb7
...
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -35,16 +35,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -35,16 +35,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp
,
BElementOp
,
CElementOp
,
CElementOp
,
GemmDefault
,
GemmDefault
,
2
,
// Prefetch stage
1
,
// Prefetch stage
128
,
// BlockSize
128
,
// BlockSize
128
,
// MPerBlock
64
,
// MPerBlock
64
,
// NPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
4
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
...
example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp
View file @
cc0ffeb7
...
@@ -21,7 +21,7 @@ using QuantDataType = int8_t;
...
@@ -21,7 +21,7 @@ using QuantDataType = int8_t;
using
BDataType
=
uint8_t
;
using
BDataType
=
uint8_t
;
using
ScaleDataType
=
ck
::
half_t
;
using
ScaleDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
CShuffleDataType
=
ck
::
half_
t
;
using
CShuffleDataType
=
floa
t
;
using
CDataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp
deleted
100644 → 0
View file @
3ba0f0d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#define CK_MNK_LOOP
namespace
ck
{
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
BDataType
,
typename
ScaleDataType
,
typename
FloatAcc
,
typename
ABlockDesc
,
typename
BBlockDesc
,
typename
ScaleBlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
Blockwise_fpAintB_GemmWMMA
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
// As Float DataType
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
ADataType
,
ADataType
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerWmma
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
if
constexpr
(
AEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
if
constexpr
(
BEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
0
,
WMMA_b_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk
();
constexpr
auto
mrepeat_mwave_mperWMMA_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperWMMA_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperWMMA_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperWMMA_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
waveId_m
,
blk_idx
[
I0
],
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
using
Tuple6
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
Blockwise_fpAintB_GemmWMMA
(
Tuple6
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple6
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
),
scale_thread_copy_
(
b_origin
)
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
static
constexpr
ScaleBlockDesc
scale_block_desc_1_n0_n1_n2_1
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
ScaleBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
const
ScaleBlockBuffer
&
scale_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
BDataType
>
(
b_thread_desc_
.
GetElementSpaceSize
());
auto
scale_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ScaleDataType
>
(
scale_thread_desc_
.
GetElementSpaceSize
());
// auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
// b_thread_desc_.GetElementSpaceSize());
tensor_operation
::
element_wise
::
FastNumericArrayConverter
<
BDataType
,
ADataType
,
WmmaK
>
fast_numeric_converter
;
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read weight scale
scale_thread_copy_
.
Run
(
scale_block_desc_1_n0_n1_n2_1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
scale_block_buf
,
scale_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
scale_thread_buf
);
vector_type
<
BDataType
,
WmmaK
>
b_int_vec
;
vector_type
<
ADataType
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
b_int_vec
.
template
AsType
<
BDataType
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec
=
fast_numeric_converter
(
b_int_vec
);
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
ADataType
>()(
i
)
=
scale_thread_buf
[
n0
]
*
b_thread_vec
.
template
AsType
<
ADataType
>()(
i
);
});
vector_type
<
ADataType
,
WmmaK
>
a_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
ADataType
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
/
A_KRow
,
m0
,
0
,
(
i
/
A_K1
)
%
A_KRow
,
0
,
i
%
A_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
ADataType
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
ADataType
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
else
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of
// k=0,kpack*1, ..
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read weight scale
scale_thread_copy_
.
Run
(
scale_block_desc_1_n0_n1_n2_1
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
scale_block_buf
,
scale_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
scale_thread_buf
);
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
BDataType
,
WmmaK
>
b_int_vec
;
vector_type
<
ADataType
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
b_int_vec
.
template
AsType
<
BDataType
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
n0
,
0
,
(
i
/
B_K1
)
%
B_KRow
,
0
,
i
%
B_K1
))
>
{}];
});
// convert B from uint8 to fp16, multiply scale
b_thread_vec
=
fast_numeric_converter
(
b_int_vec
);
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
ADataType
>()(
i
)
=
scale_thread_buf
[
n0
]
*
b_thread_vec
.
template
AsType
<
ADataType
>()(
i
);
});
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
ADataType
,
WmmaK
>
a_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
ADataType
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
/
A_KRow
,
m0
,
0
,
(
i
/
A_K1
)
%
A_KRow
,
0
,
i
%
A_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
ADataType
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
ADataType
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
}
}
protected:
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
WmmaK
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
Number
<
A_KRow
>
{},
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
*
A_KRow
>
{},
Number
<
WmmaK
>
{},
Number
<
A_K1
*
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
WmmaK
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
*
B_KRow
>
{},
Number
<
WmmaK
>
{},
Number
<
B_K1
*
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
scale_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
WmmaK
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
I1
),
make_tuple
(
I0
,
I1
,
I0
,
I0
,
I0
,
I0
));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
template
<
>
struct
AThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
ADataType
,
ADataType
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
WmmaK
/
A_K1
/
A_KRow
,
1
,
1
,
A_KRow
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
};
template
<
>
struct
AThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
<
ADataType
,
ADataType
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
WmmaK
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
false
:
true
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
BDataType
,
BDataType
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
WmmaK
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
template
<
>
struct
BThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
<
BDataType
,
BDataType
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
WmmaK
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
true
:
false
>
;
};
template
<
bool
EnableLds
>
struct
ScaleThreadCopySelector
;
template
<
>
struct
ScaleThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_block_desc_1_n0_n1_n2_1
),
decltype
(
scale_thread_desc_
),
Sequence
<
WmmaK
/
B_K1
/
B_KRow
,
1
,
1
,
B_KRow
,
1
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
,
1
>
;
};
template
<
>
struct
ScaleThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_block_desc_1_n0_n1_n2_1
),
decltype
(
scale_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
WmmaK
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
1
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
typename
ScaleThreadCopySelector
<
BEnableLds
>::
type
scale_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp
0 → 100644
View file @
cc0ffeb7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp"
namespace
ck
{
/**
* @brief Blockwise data transfer with dequantization
*
* RunRead would load low-precision data and scale data.
* RunWrite would process dequantization process.
* Assume Scale is identical along K-dimension
*
* This version does following things to avoid scratch memory issue
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
*
*/
template
<
typename
ThreadGroup
,
typename
SrcElementwiseOperation
,
typename
ScaleElementwiseOperation
,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
BlockSliceLengths
,
typename
BlockScaleSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
ScaleData
,
typename
DstData
,
typename
SrcDesc
,
typename
ScaleDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
ScaleScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
ScaleScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
,
index_t
NumThreadScratch
=
1
>
struct
ThreadGroupTensorSliceTransfer_v4r1_dequant
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
auto
thread_slice_lengths
=
BlockSliceLengths
{}
/
ThreadClusterLengths
{};
static
constexpr
auto
scale_thread_slice_lengths
=
BlockScaleSliceLengths
{}
/
ThreadClusterLengths
{};
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
ThreadGroupTensorSliceTransfer_v4r1_dequant
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
SrcElementwiseOperation
&
src_element_op
,
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_block_slice_origin
,
const
ScaleElementwiseOperation
&
scale_element_op
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
,
const
DstElementwiseOperation
&
dst_element_op
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
src_element_op
,
scale_desc
,
make_zero_multi_index
<
nDim
>
(),
scale_element_op
,
dst_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_element_op
)
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
ScaleDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
thread_slice_lengths
*
ThreadClusterLengths
{})
>
{}
&&
is_same
<
BlockScaleSliceLengths
,
decltype
(
scale_thread_slice_lengths
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
ThreadGroup
::
GetNumOfThread
()
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! ThreadGroup::GetNumOfThread() too small"
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetScaleSliceOrigin
(
scale_desc
,
scale_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
thread_scratch_id
);
}
}
// With the assumption, scale scratch is always one
template
<
typename
ScaleBuffer
>
__device__
void
RunScaleRead
(
const
ScaleDesc
&
scale_desc
,
const
ScaleBuffer
&
scale_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunScaleRead
(
scale_desc
,
scale_buf
);
}
}
template
<
typename
DstBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_desc
,
dst_buf
,
thread_scratch_id
);
}
}
// We don't prefer use this API directly
/*
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
*/
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
// With the assumption, scale buffer don't need move slice window method
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v3r1_dequant
<
decltype
(
thread_slice_lengths
),
decltype
(
scale_thread_slice_lengths
),
SrcElementwiseOperation
,
ScaleElementwiseOperation
,
DstElementwiseOperation
,
DstInMemOp
,
SrcData
,
ScaleData
,
DstData
,
SrcDesc
,
ScaleDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorDim
,
DstVectorDim
,
SrcScalarPerVector
,
ScaleScalarPerVector
,
DstScalarPerVector
,
SrcScalarStrideInVector
,
ScaleScalarStrideInVector
,
DstScalarStrideInVector
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
,
NumThreadScratch
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp
View file @
cc0ffeb7
...
@@ -66,7 +66,7 @@ template <typename ALayout,
...
@@ -66,7 +66,7 @@ template <typename ALayout,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
dequant_v1
>
ck
::
PipelineVersion
PipelineVer
=
ck
::
PipelineVersion
::
weight_only
>
struct
DeviceFpAintBGemm_Wmma_CShuffle
:
public
DeviceGemm_dequantB
<
ALayout
,
struct
DeviceFpAintBGemm_Wmma_CShuffle
:
public
DeviceGemm_dequantB
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
...
@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -95,7 +95,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
// LDS bypass feature not
checked
.
// LDS bypass feature not
implemented for dequantization pipeline
.
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
true
;
...
@@ -677,7 +677,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
...
@@ -677,7 +677,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
v2
,
"v2"
},
{
PipelineVersion
::
dequant_v1
,
"dequant_v1
"
}};
{
PipelineVersion
::
weight_only
,
"weight_only
"
}};
// clang-format off
// clang-format off
str
<<
"DeviceFpAintBGemm_Wmma_CShuffle"
str
<<
"DeviceFpAintBGemm_Wmma_CShuffle"
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
cc0ffeb7
...
@@ -405,10 +405,10 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
...
@@ -405,10 +405,10 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
half_2
[
1
]
=
__builtin_amdgcn_perm
(
fp16_adder
,
uint8_4
,
byte_selector_23
);
half_2
[
1
]
=
__builtin_amdgcn_perm
(
fp16_adder
,
uint8_4
,
byte_selector_23
);
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
asm
volatile
(
"v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]
\n
"
asm
volatile
(
"v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
:
"=v"
(
half_2
[
0
])
:
"=v"
(
half_2
[
0
])
:
"v"
(
half_2
[
0
]),
"s"
(
I8s_TO_F16s_MAGIC_NUM
));
:
"v"
(
half_2
[
0
]),
"s"
(
I8s_TO_F16s_MAGIC_NUM
));
asm
volatile
(
"v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]
\n
"
asm
volatile
(
"v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
:
"=v"
(
half_2
[
1
])
:
"=v"
(
half_2
[
1
])
:
"v"
(
half_2
[
1
]),
"s"
(
I8s_TO_F16s_MAGIC_NUM
));
:
"v"
(
half_2
[
1
]),
"s"
(
I8s_TO_F16s_MAGIC_NUM
));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
cc0ffeb7
...
@@ -9,8 +9,9 @@
...
@@ -9,8 +9,9 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_
fpAintB_
gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
@@ -82,6 +83,7 @@ __global__ void
...
@@ -82,6 +83,7 @@ __global__ void
#endif // end of if (defined(__gfx1100__))
#endif // end of if (defined(__gfx1100__))
}
}
// Assume B is Col-Major
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
...
@@ -129,7 +131,7 @@ template <index_t BlockSize,
...
@@ -129,7 +131,7 @@ template <index_t BlockSize,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
NumGemmKPrefetchStage
=
1
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
dequant_v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
weight_only
>
struct
GridwiseFpAintBGemm_Wmma
struct
GridwiseFpAintBGemm_Wmma
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -252,38 +254,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -252,38 +254,6 @@ struct GridwiseFpAintBGemm_Wmma
return
b_block_desc
;
return
b_block_desc
;
}
}
__host__
__device__
static
constexpr
auto
MakeScaleBlockDescriptor
()
{
// Scale [1, N], all K related dimension reduce to 1
constexpr
auto
scale_block_desc
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
// K0->N->K1 Per Block
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
I1
),
make_tuple
(
I0
,
I1
,
I0
));
}
else
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
),
make_tuple
(
I0
,
I1
,
I0
,
I0
,
I0
,
I0
,
I0
));
}
}();
return
scale_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
__host__
__device__
static
constexpr
auto
MakeABlockSliceCopyStep
()
{
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
constexpr
auto
a_block_copy_step
=
[
&
]()
{
...
@@ -424,47 +394,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -424,47 +394,6 @@ struct GridwiseFpAintBGemm_Wmma
return
b_wave_desc
;
return
b_wave_desc
;
}
}
template
<
typename
ScaleBlockDesc_
>
__host__
__device__
static
constexpr
auto
MakeScaleWaveDescriptor
(
const
ScaleBlockDesc_
&
)
{
constexpr
auto
scale_wave_desc
=
[
&
]()
{
if
constexpr
(
BEnableLds
)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
ScaleBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
ScaleBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_KRow
=
I1
;
return
transform_tensor_descriptor
(
ScaleBlockDesc_
{},
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
B_KRow
)),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
5
>
{}));
}
else
{
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
constexpr
auto
KWmma
=
ScaleBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
K0PerWmma
=
ScaleBlockDesc_
{}.
GetLength
(
I3
);
constexpr
auto
B_KRow
=
ScaleBlockDesc_
{}.
GetLength
(
I4
);
constexpr
auto
B_K1
=
ScaleBlockDesc_
{}.
GetLength
(
I6
);
// Workaround, Freeze transform
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmma
*
K0PerWmma
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
I0
,
I1
,
I0
,
I0
,
I0
,
I0
));
}
}();
return
scale_wave_desc
;
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
// *Caution Here repeat is shuffle repeat
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
()
...
@@ -613,8 +542,11 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -613,8 +542,11 @@ struct GridwiseFpAintBGemm_Wmma
struct
SharedMemTrait
struct
SharedMemTrait
{
{
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and Dequantized B: be careful of DataType
// scale would not put into LDS.
using
LDS_ADataType
=
ADataType
;
using
LDS_BDataType
=
ADataType
;
using
LDS_CDataType
=
CShuffleDataType
;
static
constexpr
auto
max_lds_align
=
K1
;
static
constexpr
auto
max_lds_align
=
K1
;
static
constexpr
auto
a_block_space_size_aligned
=
static
constexpr
auto
a_block_space_size_aligned
=
...
@@ -625,18 +557,13 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -625,18 +557,13 @@ struct GridwiseFpAintBGemm_Wmma
BEnableLds
?
math
::
integer_least_multiple
(
MakeBBlockDescriptor
().
GetElementSpaceSize
(),
BEnableLds
?
math
::
integer_least_multiple
(
MakeBBlockDescriptor
().
GetElementSpaceSize
(),
max_lds_align
)
max_lds_align
)
:
0
;
:
0
;
static
constexpr
auto
scale_block_space_size_aligned
=
BEnableLds
?
math
::
integer_least_multiple
(
MakeScaleBlockDescriptor
().
GetElementSpaceSize
(),
max_lds_align
)
:
0
;
static
constexpr
auto
a_block_space_offset
=
0
;
static
constexpr
auto
a_block_space_offset
=
0
;
// B would be dequantize to ADataType before enter LDS
// b_lds_offset = LDS size allocated for a in byte / LDS_BDataType
static
constexpr
auto
b_block_space_offset
=
static
constexpr
auto
b_block_space_offset
=
(
a_block_space_offset
+
a_block_space_size_aligned
)
*
sizeof
(
ADataType
)
/
(
a_block_space_offset
+
a_block_space_size_aligned
)
*
sizeof
(
LDS_ADataType
)
/
sizeof
(
BDataType
);
sizeof
(
LDS_BDataType
);
static
constexpr
auto
scale_block_space_offset
=
(
b_block_space_offset
+
b_block_space_size_aligned
)
*
sizeof
(
BDataType
)
/
sizeof
(
ScaleDataType
);
// LDS allocation for C shuffle in LDS
// LDS allocation for C shuffle in LDS
static
constexpr
auto
c_shuffle_block_space_size
=
static
constexpr
auto
c_shuffle_block_space_size
=
...
@@ -646,10 +573,9 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -646,10 +573,9 @@ struct GridwiseFpAintBGemm_Wmma
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
c_shuffle_block_space_offset
=
0
;
static
constexpr
auto
lds_size
=
static
constexpr
auto
lds_size
=
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
CShuffleDataType
),
math
::
max
(
c_shuffle_block_space_size
*
sizeof
(
LDS_CDataType
),
a_block_space_size_aligned
*
sizeof
(
ADataType
)
+
a_block_space_size_aligned
*
sizeof
(
LDS_ADataType
)
+
b_block_space_size_aligned
*
sizeof
(
BDataType
)
+
b_block_space_size_aligned
*
sizeof
(
LDS_BDataType
));
scale_block_space_size_aligned
*
sizeof
(
ScaleDataType
));
};
};
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
=
DefaultBlock2CTileMap
>
...
@@ -707,7 +633,6 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -707,7 +633,6 @@ struct GridwiseFpAintBGemm_Wmma
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
a_block_desc
=
MakeABlockDescriptor
();
constexpr
auto
b_block_desc
=
MakeBBlockDescriptor
();
constexpr
auto
b_block_desc
=
MakeBBlockDescriptor
();
constexpr
auto
scale_block_desc
=
MakeScaleBlockDescriptor
();
auto
a_block_trait
=
[
&
](){
auto
a_block_trait
=
[
&
](){
// A matrix blockwise copy
// A matrix blockwise copy
...
@@ -795,35 +720,44 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -795,35 +720,44 @@ struct GridwiseFpAintBGemm_Wmma
{
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
B
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
static_cast
<
A
DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b_block_space_offset
,
SharedMemTrait
::
b_block_space_size_aligned
);
SharedMemTrait
::
b_block_space_size_aligned
);
auto
b_blockwise_copy
=
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v4r1_dequant
<
ThisThreadBlock
,
BElementwiseOperation
,
/* typename SrcElementwiseOperation, */
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
/* typename ScaleElementwiseOperation, */
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
/* typename DstElementwiseOperation, */
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
/* InMemoryDataOperationEnum DstInMemOp, */
InMemoryDataOperationEnum
::
Set
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
/* typename BlockSliceLengths, */
Sequence
<
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterArrangeOrder
,
/* typename BlockScaleSliceLengths, */
Sequence
<
K0PerBlock
,
NPerBlock
,
I1
>
,
BDataType
,
/* typename ThreadClusterLengths, */
BBlockTransferThreadClusterLengths_K0_N_K1
,
BDataType
,
/* typename ThreadClusterArrangeOrder, */
BBlockTransferThreadClusterArrangeOrder
,
decltype
(
b_grid_desc
),
/* typename SrcData, */
BDataType
,
decltype
(
b_block_desc
),
/* typename ScaleData, */
ScaleDataType
,
BBlockTransferSrcAccessOrder
,
/* typename DstData, */
ADataType
,
Sequence
<
0
,
1
,
2
>
,
/* typename SrcDesc, */
decltype
(
b_grid_desc
),
BBlockTransferSrcVectorDim
,
/* typename ScaleDesc, */
decltype
(
scale_grid_desc
),
2
,
/* typename DstDesc, */
decltype
(
b_block_desc
),
BBlockTransferSrcScalarPerVector
,
/* typename SrcDimAccessOrder, */
BBlockTransferSrcAccessOrder
,
BBlockTransferDstScalarPerVector_K1
,
/* typename DstDimAccessOrder, */
Sequence
<
0
,
1
,
2
>
,
1
,
/* index_t SrcVectorDim, */
BBlockTransferSrcVectorDim
,
1
,
/* index_t DstVectorDim, */
2
,
BThreadTransferSrcResetCoordinateAfterRun
,
/* index_t SrcScalarPerVector, */
BBlockTransferSrcScalarPerVector
,
true
,
/* index_t ScaleScalarPerVector, */
1
,
/* index_t DstScalarPerVector, */
BBlockTransferDstScalarPerVector_K1
,
/* index_t SrcScalarStrideInVector, */
1
,
/* index_t ScaleScalarStrideInVector, */
1
,
/* index_t DstScalarStrideInVector, */
1
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
BThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
NumGemmKPrefetchStage
>
(
NumGemmKPrefetchStage
>
(
b_grid_desc
,
b_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_element_op
,
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{},
b_block_desc
,
b_block_desc
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
...
@@ -870,108 +804,22 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -870,108 +804,22 @@ struct GridwiseFpAintBGemm_Wmma
}
}
};
};
auto
scale_block_trait
=
[
&
](){
if
constexpr
(
BEnableLds
)
{
constexpr
auto
K0PerBlock
=
KPerBlock
/
K1
;
auto
scale_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ScaleDataType
*>
(
p_shared
)
+
SharedMemTrait
::
scale_block_space_offset
,
SharedMemTrait
::
scale_block_space_size_aligned
);
auto
scale_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
// Reduce slice length K1 to 1
Sequence
<
K0PerBlock
,
NPerBlock
,
I1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_grid_desc
),
decltype
(
scale_block_desc
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
1
,
1
,
1
,
// no effect
1
,
// no effect
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
scale_block_desc
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
return
make_tuple
(
scale_block_buf
,
scale_blockwise_copy
);
}
else
{
// Thread-wise copy
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
auto
scale_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ScaleDataType
>
(
scale_block_desc
.
GetElementSpaceSize
());
auto
scale_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
ScaleDataType
,
ScaleDataType
,
decltype
(
scale_grid_desc
),
decltype
(
scale_block_desc
),
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
6
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
scale_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
scale_block_buf
,
scale_blockwise_copy
);
}
};
auto
a_block_buf
=
a_block_trait
()[
I0
];
auto
a_block_buf
=
a_block_trait
()[
I0
];
auto
a_blockwise_copy
=
a_block_trait
()[
I1
];
auto
a_blockwise_copy
=
a_block_trait
()[
I1
];
auto
b_block_buf
=
b_block_trait
()[
I0
];
auto
b_block_buf
=
b_block_trait
()[
I0
];
auto
b_blockwise_copy
=
b_block_trait
()[
I1
];
auto
b_blockwise_copy
=
b_block_trait
()[
I1
];
auto
scale_block_buf
=
scale_block_trait
()[
I0
];
auto
scale_blockwise_copy
=
scale_block_trait
()[
I1
];
/*******************************************************************************/
/*******************************************************************************/
// GEMM
// GEMM
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
auto
blockwise_gemm
=
auto
blockwise_gemm
=
Blockwise
_fpAintB_
GemmWMMA
<
BlockSize
,
BlockwiseGemmWMMA
<
BlockSize
,
ADataType
,
ADataType
,
BDataType
,
ADataType
,
//Dequantized
ScaleDataType
,
AccDataType
,
AccDataType
,
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeAWaveDescriptor
(
a_block_desc
)),
decltype
(
MakeBWaveDescriptor
(
b_block_desc
)),
decltype
(
MakeBWaveDescriptor
(
b_block_desc
)),
decltype
(
MakeScaleWaveDescriptor
(
scale_block_desc
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -1006,10 +854,7 @@ struct GridwiseFpAintBGemm_Wmma
...
@@ -1006,10 +854,7 @@ struct GridwiseFpAintBGemm_Wmma
b_block_buf
,
b_block_buf
,
b_block_slice_copy_step
,
b_block_slice_copy_step
,
scale_grid_desc
,
scale_grid_desc
,
scale_block_desc
,
scale_blockwise_copy
,
scale_grid_buf
,
scale_grid_buf
,
scale_block_buf
,
blockwise_gemm
,
blockwise_gemm
,
c_thread_buf
,
c_thread_buf
,
KBlockMainLoop
);
KBlockMainLoop
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
cc0ffeb7
...
@@ -12,7 +12,7 @@ enum struct PipelineVersion
...
@@ -12,7 +12,7 @@ enum struct PipelineVersion
{
{
v1
,
v1
,
v2
,
v2
,
dequant_v1
,
weight_only
,
};
};
template
<
PipelineVersion
PipelineVer
,
template
<
PipelineVersion
PipelineVer
,
...
@@ -37,9 +37,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
...
@@ -37,9 +37,9 @@ constexpr auto GridwiseGemmPipeline_Selector()
{
{
return
GridwiseGemmPipeline_v2
{};
return
GridwiseGemmPipeline_v2
{};
}
}
else
if
constexpr
(
PipelineVer
==
PipelineVersion
::
dequant_v1
)
else
if
constexpr
(
PipelineVer
==
PipelineVersion
::
weight_only
)
{
{
return
GridwiseGemmPipeline_v1_
dequant
<
NumPrefetch
,
AEnableLds
,
BEnableLds
>
{};
return
GridwiseGemmPipeline_v1_
WeightOnly
<
NumPrefetch
,
AEnableLds
,
BEnableLds
>
{};
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
cc0ffeb7
...
@@ -551,10 +551,10 @@ struct GridwiseGemmPipeline_v1<1, false, false>
...
@@ -551,10 +551,10 @@ struct GridwiseGemmPipeline_v1<1, false, false>
};
};
template
<
index_t
NumPrefetch
,
bool
AEnableLds
,
bool
BEnableLds
>
template
<
index_t
NumPrefetch
,
bool
AEnableLds
,
bool
BEnableLds
>
struct
GridwiseGemmPipeline_v1_
dequant
;
struct
GridwiseGemmPipeline_v1_
WeightOnly
;
template
<
>
template
<
>
struct
GridwiseGemmPipeline_v1_
dequant
<
1
,
true
,
true
>
struct
GridwiseGemmPipeline_v1_
WeightOnly
<
1
,
true
,
true
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -580,10 +580,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
...
@@ -580,10 +580,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
typename
BBlockBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
BBlockTransferStep
,
typename
ScaleGridDesc
,
typename
ScaleGridDesc
,
typename
ScaleBlockDesc
,
typename
ScaleBlockTransfer
,
typename
ScaleGridBuffer
,
typename
ScaleGridBuffer
,
typename
ScaleBlockBuffer
,
typename
BlockwiseGemm
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
...
@@ -599,18 +596,16 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
...
@@ -599,18 +596,16 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
ScaleGridDesc
&
scale_grid_desc
,
const
ScaleGridDesc
&
scale_grid_desc
,
const
ScaleBlockDesc
&
scale_block_desc
,
ScaleBlockTransfer
&
scale_blockwise_copy
,
const
ScaleGridBuffer
&
scale_grid_buf
,
const
ScaleGridBuffer
&
scale_grid_buf
,
ScaleBlockBuffer
&
scale_block_buf
,
const
BlockwiseGemm
&
blockwise_gemm
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
//
preload data into LDS
//
Global Prefetch Stage 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
scale_blockwise_copy
.
RunRead
(
scale_grid_desc
,
scale_grid_buf
);
// Scale read once
b_blockwise_copy
.
RunScaleRead
(
scale_grid_desc
,
scale_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -619,8 +614,8 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
...
@@ -619,8 +614,8 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// Dequantization fused in blockwise_copy
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
scale_blockwise_copy
.
RunWrite
(
scale_block_desc
,
scale_block_buf
);
// main body
// main body
if
constexpr
(
HasMainLoop
)
if
constexpr
(
HasMainLoop
)
...
@@ -635,7 +630,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
...
@@ -635,7 +630,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
scale_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
block_sync_lds
();
...
@@ -653,118 +648,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
...
@@ -653,118 +648,7 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
{
{
block_sync_lds
();
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
scale_block_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
template
<
>
struct
GridwiseGemmPipeline_v1_dequant
<
1
,
true
,
false
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/* num_loop */
)
{
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
>
1
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockTransferStep
,
typename
ScaleGridDesc
,
typename
ScaleBlockDesc
,
typename
ScaleBlockTransfer
,
typename
ScaleGridBuffer
,
typename
ScaleBlockBuffer
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
ScaleGridDesc
&
scale_grid_desc
,
const
ScaleBlockDesc
&
scale_block_desc
,
ScaleBlockTransfer
&
scale_blockwise_copy
,
const
ScaleGridBuffer
&
scale_grid_buf
,
ScaleBlockBuffer
&
scale_block_buf
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
constexpr
auto
b_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
b_block_buf_switch
=
b_block_buf
;
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_origin_idx
,
b_block_buf
);
scale_blockwise_copy
.
Run
(
scale_grid_desc
,
scale_grid_buf
,
scale_block_desc
,
b_block_origin_idx
,
scale_block_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_origin_idx
,
b_block_buf_switch
);
block_sync_lds
();
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
scale_block_buf
,
c_thread_buf
);
block_sync_lds
();
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_block_buf
=
b_block_buf_switch
;
++
i
;
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
scale_block_buf
,
c_thread_buf
);
block_sync_lds
();
}
}
}
}
};
};
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp
0 → 100644
View file @
cc0ffeb7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace
ck
{
namespace
detail
{
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template
<
index_t
SrcVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstVectorDim
,
index_t
DstScalarPerVector
>
struct
lambda_scalar_per_access_for_src_and_dst_idle
{
__host__
__device__
constexpr
auto
operator
()(
index_t
i
)
const
{
if
(
i
==
SrcVectorDim
&&
i
==
DstVectorDim
)
{
return
math
::
lcm
(
SrcScalarPerVector
,
DstScalarPerVector
);
}
else
if
(
i
==
SrcVectorDim
)
{
return
SrcScalarPerVector
;
}
else
if
(
i
==
DstVectorDim
)
{
return
DstScalarPerVector
;
}
else
{
return
1
;
}
}
};
}
// namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
// 5. Dequantization happened between read and write.
template
<
typename
SliceLengths
,
typename
ScaleSliceLengths
,
typename
SrcElementwiseOperation
,
typename
ScaleElementwiseOperation
,
typename
DstElementwiseOperation
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
SrcData
,
typename
ScaleData
,
typename
DstData
,
typename
SrcDesc
,
typename
ScaleDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
ScaleScalarPerVector
,
index_t
DstScalarPerVector
,
index_t
SrcScalarStrideInVector
,
index_t
ScaleScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
SrcResetCoordinateAfterRun
,
// control whether to move back src coordinate after each
// RunRead(), will be fused with MoveSrcSliceWindow to
// save addr computation
bool
DstResetCoordinateAfterRun
,
// control whether to move back dst coordinate after each
// RunWrite(), will be fused with MoveDstSliceWindow to
// save addr computation
index_t
NumThreadScratch
=
1
>
struct
ThreadwiseTensorSliceTransfer_v3r1_dequant
{
static
constexpr
index_t
nDim
=
SliceLengths
::
Size
();
using
Index
=
MultiIndex
<
nDim
>
;
using
SrcCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
ScaleCoord
=
decltype
(
make_tensor_coordinate
(
SrcDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
static
constexpr
auto
I0
=
Number
<
0
>
{};
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1_dequant
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
SrcElementwiseOperation
&
src_element_op
,
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_slice_origin
,
const
ScaleElementwiseOperation
&
scale_element_op
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
,
const
DstElementwiseOperation
&
dst_element_op
)
:
src_coord_
(
make_tensor_coordinate
(
src_desc
,
src_slice_origin
)),
scale_coord_
(
make_tensor_coordinate
(
scale_desc
,
scale_slice_origin
)),
dst_coord_
(
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin
)),
src_element_op_
(
src_element_op
),
scale_element_op_
(
scale_element_op
),
dst_element_op_
(
dst_element_op
)
{
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
{
src_coord_
=
make_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
__device__
void
SetScaleSliceOrigin
(
const
ScaleDesc
&
scale_desc
,
const
Index
&
scale_slice_origin_idx
)
{
scale_coord_
=
make_tensor_coordinate
(
scale_desc
,
scale_slice_origin_idx
);
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_idx
)
{
dst_coord_
=
make_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
template
<
typename
SrcBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
static_assert
(
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
SrcBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
,
"wrong! SrcBuffer and SrcData data type are inconsistent"
);
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
// make forward steps
const
auto
src_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
src_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
src_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
src_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
src_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
src_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_src_access_lengths
)
>
{}([
&
](
auto
ordered_src_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate src data index
constexpr
auto
src_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_src_access_idx
[
i
]
:
ordered_src_access_lengths
[
i
]
-
1
-
ordered_src_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
src_dim_access_order
)
*
src_scalar_per_access
;
}();
constexpr
auto
src_data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
src_data_idx
[
i
]
>
{};
},
Number
<
src_data_idx
.
Size
()
>
{});
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
// copy data from src_buf into src_vector_container
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
is_src_valid
)};
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_
(
thread_scratch_id
)
.
template
SetAsType
<
src_vector_t
>(
src_data_idx_seq
,
src_vector_container
.
template
AsType
<
src_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_src_access_idx
[
i
]
<
ordered_src_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_src_access_idx
[
j
]
==
ordered_src_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move src coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_forward_steps
[
src_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_backward_steps
[
src_dim_access_order
[
i
]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if
constexpr
(
SrcResetCoordinateAfterRun
)
{
const
auto
src_reset_step
=
make_tensor_coordinate_step
(
src_desc
,
GetSrcCoordinateResetStep
());
move_tensor_coordinate
(
src_desc
,
src_coord_
,
src_reset_step
);
}
}
template
<
typename
ScaleBuffer
>
__device__
void
RunScaleRead
(
const
ScaleDesc
&
scale_desc
,
const
ScaleBuffer
&
scale_buf
)
{
static_assert
(
ScaleBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
ScaleBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
ScaleBuffer
::
type
>
,
remove_cvref_t
<
ScaleData
>>::
value
,
"wrong! ScaleBuffer and ScaleData data type are inconsistent"
);
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
scale_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
ScaleScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
scale_access_lengths
=
SliceLengths
{}
/
scale_scalar_per_access
;
constexpr
auto
scale_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_scale_access_lengths
=
container_reorder_given_new2old
(
scale_access_lengths
,
scale_dim_access_order
);
// make forward steps
const
auto
scale_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
scale_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
scale_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
scale_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
scale_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
scale_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_scale_access_lengths
)
>
{}([
&
](
auto
ordered_scale_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_scale_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_scale_access_lengths
[
j
]
+
ordered_scale_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate scale data index
constexpr
auto
scale_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_scale_access_idx
[
i
]
:
ordered_scale_access_lengths
[
i
]
-
1
-
ordered_scale_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
scale_dim_access_order
)
*
scale_scalar_per_access
;
}();
constexpr
auto
scale_data_idx_seq
=
generate_sequence_v2
([
&
](
auto
i
)
{
return
Number
<
scale_data_idx
[
i
]
>
{};
},
Number
<
scale_data_idx
.
Size
()
>
{});
const
bool
is_scale_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
scale_desc
,
scale_coord_
);
using
scale_vector_type
=
vector_type_maker_t
<
ScaleData
,
ScaleScalarPerVector
>
;
using
scale_vector_t
=
typename
scale_vector_type
::
type
;
// copy data from scale_buf into scale_vector_container
auto
scale_vector_container
=
scale_vector_type
{
scale_buf
.
template
Get
<
scale_vector_t
>(
scale_coord_
.
GetOffset
(),
is_scale_valid
)};
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_
.
template
SetAsType
<
scale_vector_t
>(
scale_data_idx_seq
,
scale_vector_container
.
template
AsType
<
scale_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_scale_access_idx
[
i
]
<
ordered_scale_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_scale_access_idx
[
j
]
==
ordered_scale_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move scale coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
scale_desc
,
scale_coord_
,
scale_forward_steps
[
scale_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
scale_desc
,
scale_coord_
,
scale_backward_steps
[
scale_dim_access_order
[
i
]]);
}
}
});
});
// don't need to move scale coordinate back to slice origin
/*
if constexpr(SrcResetCoordinateAfterRun)
{
const auto scale_reset_step =
make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
}
*/
}
template
<
index_t
ThreadScratchId
>
__device__
void
TransferDataFromSrcThreadScratchToDstThreadScratch
(
Number
<
ThreadScratchId
>
thread_scratch_id
)
{
#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
// convert from SrcData to DstData here
dst_thread_scratch_
(
idx
)
=
type_convert
<
DstData
>
(
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
]);
});
#else
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatype
if
constexpr
(
SrcVectorDim
!=
DstVectorDim
&&
((
is_same
<
half_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
half_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
2
==
0
&&
DstScalarPerVector
%
2
==
0
)
||
(
is_same
<
int8_t
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
int8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
constexpr
index_t
num_src_vector
=
Number
<
DstScalarPerVector
>
{};
constexpr
index_t
num_dst_vector
=
Number
<
SrcScalarPerVector
>
{};
// Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose
// TODO: make this logic generic for all scenario
static_assert
(
SrcVectorDim
!=
DstVectorDim
,
"wrong"
);
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_scalar_step_in_vector
=
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
DstVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access_for_src_and_dst_idle
<
SrcVectorDim
,
SrcScalarPerVector
,
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
scalar_per_access
;
static_ford
<
decltype
(
access_lengths
)
>
{}([
&
](
auto
access_idx
)
{
constexpr
auto
data_idx
=
access_idx
*
scalar_per_access
;
constexpr
auto
data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
data_idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
using
src_vector_t
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
// get DstScalarPerVector # of read-only references to src vectors from
// src_thread_scratch_
const
auto
src_vector_refs
=
generate_tie
(
[
&
](
auto
i
)
->
const
src_vector_t
&
{
// i increment corresponds to movement in DstVectorDim
return
src_thread_scratch_tuple_
[
thread_scratch_id
].
GetVectorTypeReference
(
data_idx_seq
+
i
*
dst_scalar_step_in_vector
);
},
Number
<
num_src_vector
>
{});
// get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_
auto
dst_vector_refs
=
generate_tie
(
[
&
](
auto
i
)
->
dst_vector_t
&
{
// i increment corresponds to movement in SrcVectorDim
return
dst_thread_scratch_
.
GetVectorTypeReference
(
data_idx_seq
+
i
*
src_scalar_step_in_vector
);
},
Number
<
num_dst_vector
>
{});
// do data transpose
transpose_vectors
<
SrcData
,
DstScalarPerVector
,
SrcScalarPerVector
>
{}(
src_vector_refs
,
dst_vector_refs
);
});
}
// Do fast numeric convert
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access_for_src_and_dst_idle
<
SrcVectorDim
,
SrcScalarPerVector
,
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
scalar_per_access
;
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
using
src_converted_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
src_converted_vector_t
=
typename
src_converted_vector_type
::
type
;
// Vector-wise type convert
static_ford
<
decltype
(
access_lengths
)
>
{}([
&
](
auto
access_idx
)
{
auto
src_vector_container
=
src_vector_type
{
src_thread_scratch_tuple_
[
thread_scratch_id
].
template
GetAsType
<
src_vector_t
>(
access_idx
)};
auto
src_converted_vector_container
=
src_converted_vector_type
{
fast_numeric_converter
(
src_vector_container
)};
src_converted_thread_scratch_
.
template
SetAsType
<
src_converted_vector_t
>(
access_idx
,
src_converted_vector_container
.
template
AsType
<
src_converted_vector_t
>()[
I0
]);
});
// Element-scale operation, expect packed multiplication
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
DstData
dst_v
;
constexpr
auto
scale_idx
=
Sequence
<
I0
,
idx
.
At
(
1
),
I0
>
{};
// printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
src_element_op_
(
dst_v
,
src_converted_thread_scratch_
[
idx
]
*
scale_thread_scratch_
[
scale_idx
]);
dst_thread_scratch_
(
idx
)
=
dst_v
;
});
#endif
}
template
<
typename
DstBuffer
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
,
Number
<
ThreadScratchId
>
thread_scratch_id
=
Number
<
ThreadScratchId
>
{})
{
// if there is transpose, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch
(
thread_scratch_id
);
static_assert
(
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
or
DstBuffer
::
GetAddressSpace
()
==
AddressSpaceEnum
::
Lds
,
"wrong!"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
// src scalar per access on each dim
// TODO: don't use this
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
ordered_dst_access_lengths
=
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
// make forward steps
const
auto
dst_forward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst_desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
// make backward steps
const
auto
dst_backward_steps
=
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
dst_scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
dst_desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
// loop over tensor and copy
static_ford
<
decltype
(
ordered_dst_access_lengths
)
>
{}([
&
](
auto
ordered_dst_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate dst data index
constexpr
auto
dst_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_dst_access_idx
[
i
]
:
ordered_dst_access_lengths
[
i
]
-
1
-
ordered_dst_access_idx
[
i
];
});
return
container_reorder_given_old2new
(
ordered_idx
,
dst_dim_access_order
)
*
dst_scalar_per_access
;
}();
constexpr
auto
dst_data_idx_seq
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
return
Number
<
dst_data_idx
[
i
]
>
{};
},
Number
<
dst_data_idx
.
Size
()
>
{});
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
dst_desc
,
dst_coord_
);
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
DstScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
// copy data from dst_thread_scratch_ into dst_vector_container
auto
dst_vector_container
=
dst_vector_type
{
dst_thread_scratch_
.
template
GetAsType
<
dst_vector_t
>(
dst_data_idx_seq
)};
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
DstData
dst_v
;
// apply DstElementwiseOperation
dst_element_op_
(
dst_v
,
dst_vector_container
.
template
AsType
<
DstData
>()[
i
]);
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
dst_v
;
});
// copy data from dst_vector_container to dst_buf
dst_buf
.
template
Set
<
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
is_dst_valid
,
dst_vector_container
.
template
AsType
<
dst_vector_t
>()[
I0
]);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
{
StaticallyIndexedArray
<
bool
,
nDim
>
move_on_dim_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
move_on_dim_
(
i
)
=
ordered_dst_access_idx
[
i
]
<
ordered_dst_access_lengths
[
i
]
-
1
;
static_for
<
i
+
1
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
move_on_dim_
(
i
)
&=
ordered_dst_access_idx
[
j
]
==
ordered_dst_access_lengths
[
j
]
-
1
;
});
});
return
move_on_dim_
;
}
();
// move dst coord
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
move_on_dim
[
i
])
{
if
constexpr
(
forward_sweep
[
i
])
{
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_forward_steps
[
dst_dim_access_order
[
i
]]);
}
else
{
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_backward_steps
[
dst_dim_access_order
[
i
]]);
}
}
});
});
// move dst coordinate back to slice origin (or not)
if
constexpr
(
DstResetCoordinateAfterRun
)
{
const
auto
dst_reset_step
=
make_tensor_coordinate_step
(
dst_desc
,
GetDstCoordinateResetStep
());
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
dst_reset_step
);
}
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
ordered_src_access_lengths
=
container_reorder_given_new2old
(
src_access_lengths
,
src_dim_access_order
);
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_src_access_lengths
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_src_access_lengths
[
j
]
+
ordered_src_access_lengths
[
j
]
-
1
;
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr
auto
src_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_src_access_lengths
[
i
]
-
1
:
0
;
});
return
container_reorder_given_old2new
(
ordered_idx
,
src_dim_access_order
)
*
src_scalar_per_access
;
}();
//
constexpr
auto
reset_src_data_step
=
[
&
]()
{
Index
reset_src_data_step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
reset_src_data_step_
(
i
)
=
-
src_data_idx
[
i
];
});
return
reset_src_data_step_
;
}();
return
reset_src_data_step
;
}
__device__
static
constexpr
auto
GetDstCoordinateResetStep
()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_dim_access_order
=
DstDimAccessOrder
{};
constexpr
auto
ordered_dst_access_lengths
=
container_reorder_given_new2old
(
dst_access_lengths
,
dst_dim_access_order
);
// judge move forward or move backward during the last iteration
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_lengths
[
I0
]
-
1
;
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_lengths
[
j
]
-
1
;
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr
auto
dst_data_idx
=
[
&
]()
{
Index
ordered_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
ordered_idx
(
i
)
=
forward_sweep
[
i
]
?
ordered_dst_access_lengths
[
i
]
-
1
:
0
;
});
return
container_reorder_given_old2new
(
ordered_idx
,
dst_dim_access_order
)
*
dst_scalar_per_access
;
}();
//
constexpr
auto
reset_dst_data_step
=
[
&
]()
{
Index
reset_dst_data_step_
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
reset_dst_data_step_
(
i
)
=
-
dst_data_idx
[
i
];
});
return
reset_dst_data_step_
;
}();
return
reset_dst_data_step
;
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_step_idx
)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const
auto
adjusted_step_idx
=
SrcResetCoordinateAfterRun
?
src_slice_origin_step_idx
:
src_slice_origin_step_idx
+
GetSrcCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
src_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
src_desc
,
src_coord_
,
adjusted_step
);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin_step_idx
)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const
auto
adjusted_step_idx
=
DstResetCoordinateAfterRun
?
dst_slice_origin_step_idx
:
dst_slice_origin_step_idx
+
GetDstCoordinateResetStep
();
// is it OK to construct a new step every time?
const
auto
adjusted_step
=
make_tensor_coordinate_step
(
dst_desc
,
adjusted_step_idx
);
move_tensor_coordinate
(
dst_desc
,
dst_coord_
,
adjusted_step
);
}
__device__
static
constexpr
auto
GetSrcThreadScratchDescriptor
()
{
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
src_access_lengths
),
Number
<
SrcScalarPerVector
>
{});
// 1st stage of transforms
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
src_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
src_access_lengths_and_vector_length
[
i
],
src_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
src_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
__device__
static
constexpr
auto
GetScaleThreadScratchDescriptor
()
{
constexpr
auto
scale_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
ScaleScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
scale_access_lengths
=
SliceLengths
{}
/
scale_scalar_per_access
;
constexpr
auto
scale_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
scale_access_lengths
),
Number
<
ScaleScalarPerVector
>
{});
// 1st stage of transforms
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
scale_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
scale_access_lengths_and_vector_length
[
i
],
scale_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
scale_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
__device__
static
constexpr
auto
GetDstThreadScratchDescriptor
()
{
// 1st stage of transforms
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_access_lengths_and_vector_length
=
container_push_back
(
sequence_to_tuple_of_number
(
dst_access_lengths
),
Number
<
DstScalarPerVector
>
{});
constexpr
auto
desc0
=
make_naive_tensor_descriptor_packed
(
dst_access_lengths_and_vector_length
);
// 2nd stage of transforms
constexpr
auto
transforms
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
DstVectorDim
)
{
return
make_merge_transform_v3_division_mod
(
make_tuple
(
dst_access_lengths_and_vector_length
[
i
],
dst_access_lengths_and_vector_length
[
Number
<
nDim
>
{}]));
}
else
{
return
make_pass_through_transform
(
dst_access_lengths_and_vector_length
[
i
]);
}
},
Number
<
nDim
>
{});
constexpr
auto
low_dim_idss
=
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
==
DstVectorDim
)
{
return
Sequence
<
i
.
value
,
nDim
>
{};
}
else
{
return
Sequence
<
i
.
value
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
up_dim_idss
=
generate_tuple
([
&
](
auto
i
)
{
return
Sequence
<
i
.
value
>
{};
},
Number
<
nDim
>
{});
return
transform_tensor_descriptor
(
desc0
,
transforms
,
low_dim_idss
,
up_dim_idss
);
}
private:
static
constexpr
auto
src_thread_scratch_desc_
=
decltype
(
GetSrcThreadScratchDescriptor
()){};
static
constexpr
auto
scale_thread_scratch_desc_
=
decltype
(
GetScaleThreadScratchDescriptor
()){};
static
constexpr
auto
dst_thread_scratch_desc_
=
decltype
(
GetDstThreadScratchDescriptor
()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using
SrcThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
SrcData
,
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
// Registers, contain fast converted data
using
SrcThreadConvertedScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
SrcScalarPerVector
,
decltype
(
src_thread_scratch_desc_
),
true
>
;
// Registers, contain scale data
using
ScaleThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
ScaleData
,
ScaleScalarPerVector
,
decltype
(
scale_thread_scratch_desc_
),
true
>
;
// Registers, contain dequantized data
using
DstThreadScratch
=
StaticTensorTupleOfVectorBuffer
<
AddressSpaceEnum
::
Vgpr
,
DstData
,
DstScalarPerVector
,
decltype
(
dst_thread_scratch_desc_
),
true
>
;
using
FastTypeConverter
=
tensor_operation
::
element_wise
::
FastNumericArrayConverter
<
SrcData
,
DstData
,
SrcScalarPerVector
>
;
StaticallyIndexedArray
<
SrcThreadScratch
,
NumThreadScratch
>
src_thread_scratch_tuple_
;
SrcThreadConvertedScratch
src_converted_thread_scratch_
;
ScaleThreadScratch
scale_thread_scratch_
;
DstThreadScratch
dst_thread_scratch_
;
FastTypeConverter
fast_numeric_converter
;
SrcCoord
src_coord_
;
ScaleCoord
scale_coord_
;
DstCoord
dst_coord_
;
const
SrcElementwiseOperation
src_element_op_
;
const
ScaleElementwiseOperation
scale_element_op_
;
const
DstElementwiseOperation
dst_element_op_
;
};
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment