gaoqiong / composable_kernel_ROCM / Commits

Commit 522b7aee, authored Jan 30, 2024 by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: ff936fd6, 84832fc4

Changes: 130 changed files in total; showing 20 changed files with 2911 additions and 443 deletions (+2911, -443).
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp                    +6    -4
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp                 +306  -0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                    +5    -0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                 +302  -1
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp    +2    -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp                      +1153 -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                        +30   -0
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                                          +1    -1
include/ck/utility/data_type.hpp                                                            +65   -0
include/ck/utility/is_known_at_compile_time.hpp                                             +7    -1
include/ck/utility/tuple_helper.hpp                                                         +11   -0
include/ck/wrapper/layout.hpp                                                               +153  -115
include/ck/wrapper/operations/copy.hpp                                                      +175  -0
include/ck/wrapper/tensor.hpp                                                               +306  -197
include/ck/wrapper/utils/layout_utils.hpp                                                   +150  -54
include/ck/wrapper/utils/tensor_partition.hpp                                               +195  -0
include/ck/wrapper/utils/tensor_utils.hpp                                                   +37   -68
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp     +2    -0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp       +3    -0
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp     +2    -0
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp

@@ -35,15 +35,17 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
     if(lengths.size() != NumDim1 + NumDim2)
     {
         std::ostringstream err;
-        err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
-            << ", in function: " << __func__;
+        err << "Incorrect number of lengths in " << "device_contraction_utils.hpp" << ":"
+            << __LINE__ << ", in function: " << __func__;
         throw std::runtime_error(err.str());
     }
     if(strides.size() != NumDim1 + NumDim2)
     {
         std::ostringstream err;
-        err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
-            << ", in function: " << __func__;
+        err << "Incorrect number of strides in " << "device_contraction_utils.hpp" << ":"
+            << __LINE__ << ", in function: " << __func__;
         throw std::runtime_error(err.str());
     }
...
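The only functional change in this file is swapping `__FILE__` for a hard-coded file name in the two error messages. A plausible motivation (not stated in the commit) is that `__FILE__` expands to whatever path the compiler was handed, so it can bake long, machine-specific build paths into error strings, while a fixed literal keeps the message stable across build environments. A minimal standalone sketch of the difference; `report()` is a hypothetical helper, not part of this repository:

#include <iostream>

// Hypothetical helper, for illustration only.
void report()
{
    // __FILE__ may expand to e.g. "/long/ci/build/path/example.cpp",
    // while the literal stays identical no matter where the file is compiled.
    std::cout << "macro  : " << __FILE__ << ":" << __LINE__ << "\n";
    std::cout << "literal: " << "example.cpp" << ":" << __LINE__ << "\n";
}

int main() { report(); }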
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
// non c-shuffle version currently has compiler issues with register spills, which in turn cause
// validation failures.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched     = make_default_loop_scheduler(),
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeTypeA       = CDataType,
          typename ComputeTypeB       = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm<ALayout,
                                                     BLayout,
                                                     CLayout,
                                                     ADataType,
                                                     BDataType,
                                                     CDataType,
                                                     AElementwiseOperation,
                                                     BElementwiseOperation,
                                                     CElementwiseOperation>
{
    using DeviceOp = DeviceGemm_Xdl_CShuffleV2;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2<
        ALayout,
        BLayout,
        CLayout,
        ADataType,
        BDataType,
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        GemmSpec,
        InMemoryDataOperationEnum::Set,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched,
        PipelineVer,
        ComputeTypeA,
        ComputeTypeB>;

    using Argument = typename GridwiseGemm::Argument;

    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                arg.Print();
            }

            if(!GridwiseGemm::CheckValidity(arg))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);

            float ave_time = 0;

            const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;

            if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3)
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            }
            else
            {
                const auto kernel = kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) &&
           !(GemmSpec == GemmSpecialization::MKPadding ||
             GemmSpec == GemmSpecialization::NKPadding ||
             GemmSpec == GemmSpecialization::MNKPadding ||
             GemmSpec == GemmSpecialization::KPadding))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation,
                             BElementwiseOperation,
                             CElementwiseOperation)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceGemm_Xdl_CShuffleV2"
            << "<"
            << getGemmSpecializationString(GemmSpec) << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ABlockTransferSrcScalarPerVector << ", "
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle
            << ">"
            << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: " << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
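For orientation, the sketch below shows how a device op of this kind is typically driven from the host in composable_kernel examples: build an Argument, check support, then time the kernel through an Invoker. This is a sketch only, not part of the commit; the concrete template parameters (layouts, data types, tile sizes) are elided behind `DeviceGemmInstance`, and the buffer pointers, problem sizes, and `StreamConfig{nullptr, true}` usage follow the pattern of existing CK examples rather than anything defined in this file:

// Sketch only. Assumes CK headers are included and some concrete instantiation exists, e.g.
//   using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2<...>;
// with device buffers p_a, p_b, p_c already allocated and filled.
template <typename DeviceGemmInstance, typename ADataType, typename BDataType, typename CDataType>
float run_gemm(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
               ck::index_t M, ck::index_t N, ck::index_t K,
               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
{
    auto gemm = DeviceGemmInstance{};

    // The last three parameters are the A/B/C element-wise ops (PassThrough-style functors).
    auto argument =
        gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, {}, {}, {});

    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error("this problem is not supported by the chosen instance");
    }

    auto invoker = gemm.MakeInvoker();

    // Returns the averaged kernel time; StreamConfig{nullptr, true} enables kernel timing.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}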
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -174,6 +174,11 @@ struct PassThrough
     {
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<int4_t, int>(int4_t& y, const int& x) const
+    {
+        y = type_convert<int4_t>(x);
+    }
 #endif

     template <>
...
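The added specialization extends `PassThrough` so that an `int` source can be written to an `int4_t` destination through `type_convert`, matching the other per-type specializations already in this struct. As a rough standalone analogue of that pattern (plain C++, no CK types; `to_int8` stands in for `type_convert`, and an overload stands in for the explicit specialization):

#include <cstdint>
#include <iostream>

// Stand-in for ck::type_convert: an explicit, visible narrowing point.
inline std::int8_t to_int8(int x) { return static_cast<std::int8_t>(x); }

struct PassThroughLike
{
    // Generic case: plain copy.
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = x; }

    // Mixed-type case: route through an explicit conversion, like the new
    // int4_t <- int specialization in the real header.
    void operator()(std::int8_t& y, const int& x) const { y = to_int8(x); }
};

int main()
{
    std::int8_t y{};
    PassThroughLike{}(y, 42);
    std::cout << int(y) << "\n"; // prints 42
}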
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -134,6 +134,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
         : M_(M), N_(N), M01_(M01)
     {
+#if 0
+        if(get_thread_global_1d_id()==0){
+            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
+        }
+#endif
     }

     template <typename CGridDesc_M_N>
...
@@ -252,6 +257,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt<MPerBlo
         BlockToCTileMap_M00_N0_M01Adapt;
 };

+// Rows of column-vectors
+// This C-tile map dynamically adjusts M01 when C-tile index is out of range
+template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
+struct BlockToCTileMap_Grouped_M00_N0_M01Adapt;
+
+template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock>
+struct BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
+        const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
+        BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default;
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default;
+
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
+                                                                index_t N,
+                                                                index_t M01 = 8)
+        : M_(M), N_(N), M01_(M01)
+    {
+#if 0
+        if(get_thread_global_1d_id()==0){
+            printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_);
+        }
+#endif
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(
+        const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
+        : BlockToCTileMap_Grouped_M00_N0_M01Adapt(
+              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
+    {
+    }
+
+    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    {
+        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+        return M0 * N0;
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    {
+        return true;
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        auto block_1d_id = idx_top[I0];
+
+        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
+
+        const auto group_size  = math::integer_divide_ceil(M0 * N0, GroupNum);
+        auto group_id          = block_1d_id % GroupNum;
+        auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum;
+
+        index_t idx_N0 = remap_block_1d_id % N0;
+        index_t idx_M0 = remap_block_1d_id / N0;
+
+        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;
+
+        index_t idx_M00          = idx_M0 / M01_;
+        index_t idx_M01          = idx_M0 % M01_;
+        index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+        /**
+         * [ASCII diagram omitted: it sketches how consecutive 1D block ids walk the
+         *  (M0 x N0) grid of C tiles in column-vector groups of height M01.]
+         *
+         * Example:
+         *  assume:
+         *      M0 = 5
+         *      N0 = 4
+         *      block_1d_id = 5
+         *      M01 = 2
+         *
+         *  idx_N0 = 1
+         *  idx_M0 = 1
+         *  M01_adapt = 2
+         *  idx_M00 = 0
+         *  idx_M01 = 1
+         *  idx_N0_M01_local = 5
+         *  output {1, 2}
+         */
+
+        return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+                          idx_N0_M01_local / M01_adapt);
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                             const CTileDim& /* c_tile_dim */) const
+    {
+        return true; // always valid provided that user gets grid size from CalculateGridSize()
+    }
+
+    private:
+    index_t M_;
+    index_t N_;
+    index_t M01_;
+};
+
+// keep the redundant type argument for backward compatibility
+template <index_t GroupNum, index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
+struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
+    : BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>
+{
+    using BlockToCTileMap_Grouped_M00_N0_M01Adapt<GroupNum, MPerBlock, NPerBlock, void>::
+        BlockToCTileMap_Grouped_M00_N0_M01Adapt;
+};
+
+// columns of row-vectors
+// This C-tile map dynamically adjusts N01 when C-tile index is out of range
+template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N = void>
+struct BlockToCTileMap_N00_M0_N01Adapt;
+
+template <index_t MPerBlock, index_t NPerBlock>
+struct BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock, void>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default;
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) =
+        default;
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) =
+        default;
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
+    operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default;
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt&
+    operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default;
+
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8)
+        : M_(M), N_(N), N01_(N01)
+    {
+#if 0
+        if(get_thread_global_1d_id()==0){
+            printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_);
+        }
+#endif
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                        index_t N01 = 8)
+        : BlockToCTileMap_N00_M0_N01Adapt(
+              c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01)
+    {
+    }
+
+    __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
+    {
+        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
+
+        return M0 * N0;
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+    {
+        return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    {
+        return true;
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        auto block_1d_id = idx_top[I0];
+
+        const auto M0 = math::integer_divide_ceil(M_, MPerBlock);
+        const auto N0 = math::integer_divide_ceil(N_, NPerBlock);
+
+        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
+
+        index_t idx_M0 = block_1d_id % M0;
+        index_t idx_N0 = block_1d_id / M0;
+
+        const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_;
+
+        index_t idx_N00          = idx_N0 / N01_;
+        index_t idx_N01          = idx_N0 % N01_;
+        index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0;
+
+        /**
+         * [ASCII diagram omitted: it sketches how consecutive 1D block ids walk the
+         *  (M0 x N0) grid of C tiles in row-vector groups of width N01.]
+         *
+         * Example:
+         *  assume:
+         *      N0 = 5
+         *      M0 = 4
+         *      block_1d_id = 5
+         *      N01 = 2
+         *
+         *  idx_M0 = 1
+         *  idx_N0 = 1
+         *  N01_adapt = 2
+         *  idx_N00 = 0
+         *  idx_N01 = 1
+         *  idx_M0_N01_local = 5
+         *  output {2, 1}
+         */
+
+        return make_tuple(idx_M0_N01_local / N01_adapt,
+                          idx_M0_N01_local % N01_adapt + idx_N00 * N01_);
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                             const CTileDim& /* c_tile_dim */) const
+    {
+        return true; // always valid provided that user gets grid size from CalculateGridSize()
+    }
+
+    private:
+    index_t M_;
+    index_t N_;
+    index_t N01_;
+};
+
 // 2D slices of column-vectors in 3D space
 // This C-tile map dynamically adjusts M01 when C-tile index is out of range
 template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
...
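The two worked examples embedded in the comments above are easy to check on the host. The sketch below (not part of the commit) re-implements just the M01/N01 "adapt" arithmetic from the two CalculateBottomIndex functions in plain C++; the grouped variant's extra group remap is skipped, which is consistent with the documented example values. Running it prints the documented outputs {1, 2} and {2, 1}:

#include <cstdio>
#include <utility>

// Plain-C++ restatement of the index arithmetic in CalculateBottomIndex of
// BlockToCTileMap_Grouped_M00_N0_M01Adapt, assuming the group remap is the identity.
std::pair<int, int> m01_adapt_map(int M0, int N0, int block_1d_id, int M01)
{
    int idx_N0 = block_1d_id % N0;
    int idx_M0 = block_1d_id / N0;

    int M01_adapt = (idx_M0 < M0 - M0 % M01) ? M01 : M0 % M01;

    int idx_M00          = idx_M0 / M01;
    int idx_M01          = idx_M0 % M01;
    int idx_N0_M01_local = idx_N0 + idx_M01 * N0;

    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01, idx_N0_M01_local / M01_adapt};
}

// Same restatement for BlockToCTileMap_N00_M0_N01Adapt.
std::pair<int, int> n01_adapt_map(int M0, int N0, int block_1d_id, int N01)
{
    int idx_M0 = block_1d_id % M0;
    int idx_N0 = block_1d_id / M0;

    int N01_adapt = (idx_N0 < N0 - N0 % N01) ? N01 : N0 % N01;

    int idx_N00          = idx_N0 / N01;
    int idx_N01          = idx_N0 % N01;
    int idx_M0_N01_local = idx_M0 + idx_N01 * M0;

    return {idx_M0_N01_local / N01_adapt, idx_M0_N01_local % N01_adapt + idx_N00 * N01};
}

int main()
{
    auto a = m01_adapt_map(/*M0*/ 5, /*N0*/ 4, /*block_1d_id*/ 5, /*M01*/ 2);
    auto b = n01_adapt_map(/*M0*/ 4, /*N0*/ 5, /*block_1d_id*/ 5, /*N01*/ 2);
    std::printf("{%d, %d}\n", a.first, a.second); // {1, 2}
    std::printf("{%d, %d}\n", b.first, b.second); // {2, 1}
}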
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp

@@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
                                const InDataTypePointerTuple p_in_global_tuple,
-                               XDataType* const __restrict__ p_x_lds,
+                               XDataType* const __restrict__ p_x_lds_,
                                const GammaDataType* const __restrict__ p_gamma_global,
                                const BetaDataType* const __restrict__ p_beta_global,
                                YDataType* const __restrict__ p_y_global,
@@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
             p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

         auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);
+            p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);

         auto in_thread_buf_tuple = generate_tuple(
             [&](auto) {
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
        kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
    // operate on different LDS chunks at the same time, without an ordering dependency.
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
        karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
#else
    ignore = karg;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template <typename GridwiseGemm,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif
        kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
                                    const FloatB* p_b_grid,
                                    FloatC* p_c_grid,
                                    typename GridwiseGemm::Problem problem)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(
        p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_c_grid;
    ignore = problem;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
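// Editor's note (illustrative sketch, not part of this header): the two __shared__ arrays above
// implement a ping-pong (double-buffered) LDS scheme. A hypothetical outline of the idea, with
// load_tile()/compute_tile() standing in for the blockwise copy and the MFMA pipeline:
//
//     char* bufs[2] = {p_shared_0, p_shared_1};
//     for(int k = 0; k < num_k_blocks; ++k)
//     {
//         char* cur  = bufs[k % 2];        // tile being computed on
//         char* next = bufs[(k + 1) % 2];  // tile being prefetched
//         load_tile(next, k + 1);          // global -> LDS for the next iteration
//         compute_tile(cur);               // MFMA on the current LDS tile
//     }
//
// Because the two buffers never alias, the ds_read (compute) and ds_write (prefetch) of one
// iteration need no ordering between them, which is what the comment in the first kernel
// alludes to.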
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename FloatA,
          typename FloatB,
          typename FloatGemmAcc,
          typename FloatCShuffle,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1,
          typename ComputeTypeA       = FloatC,
          typename ComputeTypeB       = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    __host__ static auto CalculateGridSize(index_t M, index_t N)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

    __host__ static auto CalculateAK0(index_t K)
    {
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding ||
                     GemmSpec == GemmSpecialization::KPadding ||
                     GemmSpec == GemmSpecialization::NKPadding)
        {
            return CalculateKPadded(K) / AK1Value;
        }
        else
        {
            return K / AK1Value;
        }
    }

    __host__ static auto CalculateBK0(index_t K)
    {
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding ||
                     GemmSpec == GemmSpecialization::KPadding ||
                     GemmSpec == GemmSpecialization::MKPadding)
        {
            return CalculateKPadded(K) / BK1Value;
        }
        else
        {
            return K / BK1Value;
        }
    }

    __host__ static auto CalculateMBlock(index_t M)
    {
        return math::integer_divide_floor(M, MPerBlock);
    }

    __host__ static auto CalculateNBlock(index_t N)
    {
        return math::integer_divide_floor(N, NPerBlock);
    }
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }
    __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(M, MPad - M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_pass_through_transform(M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
    __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(N, NPad - N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_pass_through_transform(N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KPadded{CalculateKPadded(K_)},
              AK0{CalculateAK0(K_)},
              BK0{CalculateBK0(K_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t MPadded;
        index_t NPadded;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };
    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const FloatA* p_a_grid_,
                          const FloatB* p_b_grid_,
                          FloatC* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const FloatA* p_a_grid;
        const FloatB* p_b_grid;
        FloatC* p_c_grid;
    };

    // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
    }

    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        return make_naive_tensor_descriptor(
            make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
    }

    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
                          b_block_space_size_aligned * sizeof(ComputeTypeB)),
                         c_block_size * sizeof(FloatCShuffle));
    }
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ static constexpr bool CheckValidity(const Problem& problem)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.M % MPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.N % NPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
                     GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding)
        {
            if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
               !(CalculateKPadded(problem.K) % BK1Value == 0))
            {
                return false;
            }
        }
        else
        {
            if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.M % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(problem.N % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.K % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                return false;
            }
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;

        if(num_k_loop < 4)
        {
            return false;
        }

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return num_loop > 3;
    }

    __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        if(num_loop % 2 == 1)
            return 3;
        else
            return 2;
    }
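    // Editor's note (illustrative, not part of this header): the tail number depends only on
    // the parity of K / KPerBlock. For a hypothetical KPerBlock = 64:
    //   K = 512 -> num_loop = 8 (even) -> TailNum = 2
    //   K = 576 -> num_loop = 9 (odd)  -> TailNum = 3
    // The device op in device_gemm_xdl_cshuffle_v2.hpp uses this value to choose between
    // kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true> (default TailNum = 3) and
    // kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>.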
    template <typename CGridDesc>
    __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    template <bool HasMainKBlockLoop, index_t TailNum = 3>
    __device__ static void Run(const FloatA* p_a_grid,
                               const FloatB* p_b_grid,
                               FloatC* p_c_grid,
                               void* p_shared_0,
                               void* p_shared_1,
                               const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

#if 0
        if(threadIdx.x == 0){
            printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
            get_block_1d_id(),
            block_work_idx[I0],
            block_work_idx[I1],
            __smid()>>6 & 0xf,
            __smid()>>4 & 0x3,
            __smid() & 0xf);
        }
#endif
        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<AK0Number, MPerBlock, AK1Number>,
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                FloatA,
                                                ComputeTypeA,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                ABlockTransferSrcVectorDim,
                                                2,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                ck::tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<BK0Number, NPerBlock, BK1Number>,
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                FloatB,
                                                ComputeTypeB,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<1, 0, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                ck::tensor_operation::element_wise::PassThrough{});
        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1Number, BK1Number),
                      MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
        //     BlockSize,
        //     ComputeType,
        //     FloatGemmAcc,
        //     decltype(a_block_desc_ak0_m_ak1),
        //     decltype(b_block_desc_bk0_n_bk1),
        //     MPerXdl,
        //     NPerXdl,
        //     MXdlPerWave,
        //     NXdlPerWave,
        //     KPack,
        //     LoopSched>();

        auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
            BlockSize,
            ComputeTypeA,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
            decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
            decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerXdl,
            NPerXdl,
            MXdlPerWave,
            NXdlPerWave,
            KPack>{};
        // TransposeC

        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // gridwise GEMM pipeline
        static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
        // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<FloatCShuffle*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
                                                   FloatCShuffle,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                FloatCShuffle,        // typename SrcData,
                FloatC,               // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
});
}
}
};
}
// namespace ck
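The two p_shared allocations above form a ping/pong pair: while the blockwise GEMM pipeline consumes the K tile staged in one LDS buffer, the copy units fill the other, and the roles swap each step. A minimal host-side sketch of the same double-buffering idea, with hypothetical names and no relation to the CK pipeline API:

#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

// Toy double-buffered loop: stage tile k+1 into the idle buffer while
// "computing" (here: accumulating) on tile k in the active buffer.
float pipelined_sum(const std::vector<float>& src, std::size_t tile)
{
    std::array<std::vector<float>, 2> buf{std::vector<float>(tile), std::vector<float>(tile)};
    const std::size_t num_tiles = src.size() / tile;
    float acc = 0.f;
    std::copy_n(src.begin(), tile, buf[0].begin()); // prologue: stage the first tile ("ping")
    for(std::size_t k = 0; k < num_tiles; ++k)
    {
        if(k + 1 < num_tiles) // stage the next tile into the other buffer ("pong")
            std::copy_n(src.begin() + (k + 1) * tile, tile, buf[(k + 1) % 2].begin());
        for(float v : buf[k % 2]) // compute on the current buffer
            acc += v;
    }
    return acc;
}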
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...
@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+        {
+            const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+                a_grid_desc_m_k,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return transform_tensor_descriptor(
+                a_grid_desc_m_kpad,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_pass_through_transform(M)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
         else
         {
             return transform_tensor_descriptor(
...
@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
         }
+        else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding)
+        {
+            const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+                b_grid_desc_k_n,
+                make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+            return transform_tensor_descriptor(
+                b_grid_desc_kpad_n,
+                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+        }
         else
         {
             return transform_tensor_descriptor(
...
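Both new KPadding branches follow the same recipe: right-pad K up to KPad, then unmerge the padded extent into (KBatch, K0Padded, K1). A small arithmetic sketch of the sizes those two transforms presuppose (helper names here are illustrative, not the GridwiseGemm members):

#include <cassert>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// KPad must be a multiple of KBatch * K1 so the unmerge is exact.
constexpr int k0_padded(int K, int K1, int KBatch) { return integer_divide_ceil(K, K1 * KBatch); }
constexpr int k_pad(int K, int K1, int KBatch) { return k0_padded(K, K1, KBatch) * K1 * KBatch; }

static_assert(k0_padded(/*K=*/100, /*K1=*/8, /*KBatch=*/2) == 7);
static_assert(k_pad(100, 8, 2) == 112); // 12 padded elements appended by make_right_pad_transform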
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
...
@@ -328,7 +328,7 @@ struct WmmaSelector
     }

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
-    static constexpr auto GetWmma<int4_t, int, 16, 16>()
+    static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
     {
         return WmmaInstr::wmma_i32_16x16x16_iu4;
     }
...
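The one-line fix above supplies the second source type that the int4 member specialization of GetWmma was missing. A stripped-down illustration of the same pattern (explicit specialization of a member function template used as an instruction selector; all names hypothetical):

enum class Instr { fp16, int8, int4 };
struct Int4Tag {};

struct Selector
{
    // Primary template: one entry per (SrcA, SrcB, Dst, M, N) combination.
    template <typename SrcA, typename SrcB, typename Dst, int M, int N>
    static constexpr Instr Get();
};

// A full specialization must spell out every template argument.
template <>
constexpr Instr Selector::Get<Int4Tag, Int4Tag, int, 16, 16>() { return Instr::int4; }

static_assert(Selector::Get<Int4Tag, Int4Tag, int, 16, 16>() == Instr::int4);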
include/ck/utility/data_type.hpp
...
@@ -189,6 +189,7 @@ struct vector_type<T, 1>
     }
 };

+int static err = 0;
 template <typename T>
 struct vector_type<T, 2>
 {
...
@@ -221,6 +222,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -236,6 +241,10 @@ struct vector_type<T, 2>
         {
             return data_.d2x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -278,6 +287,10 @@ struct vector_type<T, 4>
         {
             return data_.d4x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -298,6 +311,10 @@ struct vector_type<T, 4>
         {
             return data_.d4x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -347,6 +364,10 @@ struct vector_type<T, 8>
         {
             return data_.d8x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -372,6 +393,10 @@ struct vector_type<T, 8>
         {
             return data_.d8x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -428,6 +453,10 @@ struct vector_type<T, 16>
         {
             return data_.d16x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -458,6 +487,10 @@ struct vector_type<T, 16>
         {
             return data_.d16x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -520,6 +553,10 @@ struct vector_type<T, 32>
         {
             return data_.d32x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -554,6 +591,10 @@ struct vector_type<T, 32>
         {
             return data_.d32x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -623,6 +664,10 @@ struct vector_type<T, 64>
         {
             return data_.d64x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -662,6 +707,10 @@ struct vector_type<T, 64>
         {
             return data_.d64x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -737,6 +786,10 @@ struct vector_type<T, 128>
         {
             return data_.d128x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -780,6 +833,10 @@ struct vector_type<T, 128>
         {
             return data_.d128x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
@@ -861,6 +918,10 @@ struct vector_type<T, 256>
         {
             return data_.d256x1_;
         }
+        else
+        {
+            return err;
+        }
     }

     template <typename X>
...
@@ -908,6 +969,10 @@ struct vector_type<T, 256>
         {
             return data_.d256x1_;
         }
+        else
+        {
+            return err;
+        }
     }
 };
...
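The repeated hunks above all add the same fallback: each AsType() chain of if-constexpr branches now ends in an else that returns the file-scope err object, so instantiations that fall through no longer reach the end of a non-void function. A compressed sketch of that pattern outside of CK (simplified types, not the real vector_type):

#include <type_traits>

static int err = 0; // fallback storage, mirrors the `int static err = 0;` added above

template <typename T>
struct toy_vector2
{
    T data_[2];

    // Discarded if-constexpr branches still need a well-formed return path.
    template <typename X>
    auto& as()
    {
        if constexpr(std::is_same_v<X, T>) { return data_[0]; } // scalar access
        else { return err; }                                    // unsupported X
    }
};

static_assert(std::is_same_v<decltype(toy_vector2<float>{}.as<int>()), int&>);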
include/ck/utility/is_known_at_compile_time.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -19,6 +19,12 @@ struct is_known_at_compile_time<index_t>
     static constexpr bool value = false;
 };

+template <>
+struct is_known_at_compile_time<unsigned int>
+{
+    static constexpr bool value = false;
+};
+
 template <>
 struct is_known_at_compile_time<long_index_t>
 {
...
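The new `unsigned int` entry above follows the trait's existing convention: plain runtime integer types report false, while integral-constant-like types report true. A simplified stand-in for that pattern (not the CK definitions):

#include <type_traits>

template <typename T>
struct known_at_compile_time : std::false_type {};

// Compile-time constants (Number<>-like wrappers) specialize to true.
template <typename T, T v>
struct known_at_compile_time<std::integral_constant<T, v>> : std::true_type {};

static_assert(!known_at_compile_time<unsigned int>::value);
static_assert(known_at_compile_time<std::integral_constant<int, 4>>::value);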
include/ck/utility/tuple_helper.hpp
...
@@ -178,4 +178,15 @@ __host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
     return math::max(TupleDepth<depth + 1>(Ts{})...);
 }

+template <index_t from, index_t to, typename... Ts>
+__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
+{
+    return generate_tuple(
+        [&](auto i) {
+            using Idx = Number<from + i>;
+            return tuple.At(Idx{});
+        },
+        Number<to - from>{});
+}
+
 } // namespace ck
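TupleSlice copies the half-open index range [from, to) of a tuple into a new tuple. The same semantics written against std::tuple, as a quick reference (std-based stand-in, not the ck::Tuple implementation):

#include <tuple>
#include <utility>

template <std::size_t From, std::size_t To, typename Tuple, std::size_t... Is>
constexpr auto tuple_slice_impl(const Tuple& t, std::index_sequence<Is...>)
{
    return std::make_tuple(std::get<From + Is>(t)...);
}

// Keep elements From .. To-1, drop the rest.
template <std::size_t From, std::size_t To, typename... Ts>
constexpr auto tuple_slice(const std::tuple<Ts...>& t)
{
    return tuple_slice_impl<From, To>(t, std::make_index_sequence<To - From>{});
}

static_assert(tuple_slice<1, 3>(std::make_tuple(2, 4, 8, 16)) == std::make_tuple(4, 8));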
include/ck/wrapper/layout.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -14,24 +14,28 @@ namespace wrapper {
  * \tparam Shape Tuple of Number<> (for compile-time layout) or index_t
  * (dynamic layout). It is possible to pass nested shapes
  * (e.g. ((4, 2), 2)), nested dimensions are merged.
- * \tparam Strides Tuple of Number<> (for compile-time layout) or index_t
- * (dynamic layout). Stride tuple should be nested if shape tuple is
- * nested.
+ * \tparam UnrolledDescriptorType Tensor descriptor for unnested shape dims.
  */
-template <typename Shape, typename Strides>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout
 {
     private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    // Generate default idxs tuple (idx with all merged nested shapes)
+    /**
+     * \brief Generate default indices tuple (idx with all merged nested shapes)
+     *
+     * \param shape Shape to align.
+     * \return Multi idx tuple with zeros.
+     */
     template <typename... Ts>
-    __host__ __device__ constexpr static auto GenerateDefaultIdxsTuple(const Tuple<Ts...>&)
+    __host__ __device__ constexpr static auto
+    GenerateDefaultIdxsTuple([[maybe_unused]] const Tuple<Ts...>& shape)
     {
         return generate_tuple(
             [&](auto) {
-                if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
+                if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
                 {
                     // runtime layout
                     return index_t(0);
...
@@ -45,32 +49,18 @@ struct Layout
             Number<Tuple<Ts...>::Size()>{});
     }

-    // Generate packed (column-major) strides if not passed
-    template <typename... Ts>
-    __host__ __device__ constexpr static auto
-    GenerateColumnMajorPackedStrides(const Tuple<Ts...>& shape)
-    {
-        const auto unrolled_shape = UnrollNestedTuple(shape);
-        return generate_tuple(
-            [&](auto i) {
-                if constexpr(i.value == 0)
-                {
-                    return I1;
-                }
-                else
-                {
-                    return TupleReduce<I0.value, i.value>([](auto x, auto y) { return x * y; },
-                                                          unrolled_shape);
-                }
-            },
-            Number<decltype(unrolled_shape)::Size()>{});
-    }
-
-    // Generate LowerDims in compile-time for MergeTransform using passed Type
-    // If element of Tuple<Ts...> is also tuple, then merge (generate sequence for merge)
-    // If tuple is element, then pass through (sequence with one element)
+    /**
+     * \brief Generate lower dims in compile-time for the Merge transform using
+     * provided type. If element of nested Tuple<Ts...> is also a tuple, then
+     * merge (generate sequence for merge). If tuple is element, then pass
+     * through (sequence with one element).
+     *
+     * \param shape Shape to align.
+     * \return LowerDims for MergeTransform.
+     */
     template <typename Idx, typename... Ts>
-    __host__ __device__ constexpr static auto GenerateLowerDim(const Tuple<Ts...>&)
+    __host__ __device__ constexpr static auto
+    GenerateLowerDim([[maybe_unused]] const Tuple<Ts...>& shape)
     {
         if constexpr(Idx::value == 0)
         {
...
@@ -110,11 +100,17 @@ struct Layout
         }
     }

-    // Iterate over nested tuples in shape
-    // Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
-    // Example idx: (1, 1), 1, 1
-    // Example shape: (2, (2, 2)), 2, (2, 2)
-    // Unrolled shape: 2, (2, 2), 2, (2, 2)
+    /**
+     * \brief Iterate over the nested tuples in the shape.
+     * Unroll nested tuples to align Tuple<ShapeDims...> to Tuple<IdxDims...>
+     * Example idx: (1, 1), 1, 1
+     * Example shape: (2, (2, 2)), 2, (2, 2)
+     * Unrolled shape: 2, (2, 2), 2, (2, 2)
+     *
+     * \param shape Layout shape.
+     * \param idx Idx to align.
+     * \return Aligned shape.
+     */
     template <typename... ShapeDims, typename... IdxDims>
     __host__ __device__ constexpr static auto AlignShapeToIdx(const Tuple<ShapeDims...>& shape,
                                                               const Tuple<IdxDims...>& idx)
...
@@ -149,6 +145,13 @@ struct Layout
         }
     }

+    /**
+     * \brief Merge descriptor to 1D.
+     *
+     * \param shape Layout shape.
+     * \param desc Descriptor to merge.
+     * \return 1D descriptor.
+     */
     template <typename... ShapeDims, typename DescriptorToMerge>
     __host__ __device__ constexpr static auto MakeMerge1d(const Tuple<ShapeDims...>& shape,
                                                           const DescriptorToMerge& desc)
...
@@ -160,18 +163,41 @@ struct Layout
         const auto lower_dims = make_tuple(MergeElemsSequence::Reverse());
         const auto upper_dims = make_tuple(Sequence<0>{});
         // Merge to 1d
+        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+        {
             return transform_tensor_descriptor(
                 desc, make_tuple(make_merge_transform(merge_elems)), lower_dims, upper_dims);
+        }
+        else
+        {
+            // If the descriptor is known at the compilation time,
+            // use `make_merge_transform_v1_carry_check` because it doesn't use
+            // memcpy.
+            return transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform_v1_carry_check(merge_elems)),
+                lower_dims,
+                upper_dims);
+        }
     }

-    // Merge nested shape dims when corresponding index is also nested.
-    // Input desc shape: 2, 2, 2, 2, 2, 2
-    // Example idx: 1, 1, 1, 1
-    // Example shape: 2, (2, 2), 2, (2, 2)
-    // Merged shape: 2, 4, 2, 4
+    /**
+     * \brief Merge nested shape dims when corresponding index is also merged.
+     * Input desc shape: 2, 2, 2, 2, 2, 2
+     * Example idx: 1, 1, 1, (1, 1)
+     * Example shape: 2, (2, 2), 2, (2, 2)
+     * Merged shape: 2, 4, 2, 2, 2
+     *
+     * \param shape Layout shape.
+     * \param idxs Indexes to align descriptor.
+     * \param desc Descriptor to merge.
+     * \return Aligned descriptor to idx.
+     */
     template <typename... ShapeDims, typename... IdxDims, typename DescriptorToMerge>
     __host__ __device__ constexpr static auto
-    CreateMergedDescriptor(const Tuple<ShapeDims...>& shape, const Tuple<IdxDims...>&,
-                           DescriptorToMerge& desc)
+    CreateMergedDescriptor(const Tuple<ShapeDims...>& shape,
+                           [[maybe_unused]] const Tuple<IdxDims...>& idxs,
+                           DescriptorToMerge& desc)
     {
         const auto transforms = generate_tuple(
             [&](auto i) {
...
@@ -183,9 +209,19 @@ struct Layout
                     // If shape element is tuple and idx element is Number, then merge
                     // Unroll and reverse tuple to traverse column-major
                     const auto merge_elems = TupleReverse(UnrollNestedTuple(shape.At(i)));
+                    if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
+                    {
                         return make_merge_transform(merge_elems);
+                    }
+                    else
+                    {
+                        // If the descriptor is known at the compilation time,
+                        // use `make_merge_transform_v1_carry_check` because
+                        // it doesn't use memcpy.
+                        return make_merge_transform_v1_carry_check(merge_elems);
+                    }
                 }
                 else
                 {
                     // If shape element is integer and idx element is tuple, passed idx is wrong
                     static_assert(
...
@@ -207,33 +243,24 @@ struct Layout
         return transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
     }

-    template <typename LayoutShape, typename LayoutStrides>
-    __host__ __device__ static auto MakeFlattenDescriptor(const LayoutShape& shape,
-                                                          const LayoutStrides& strides)
-    {
-        const auto unrolled_shape   = UnrollNestedTuple(shape);
-        const auto unrolled_strides = UnrollNestedTuple(strides);
-        static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
-                      "Size of strides and shape are not consistent.");
-        return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
-    }
-
-    // If the stride is not passed, you can infer it from `GenerateColumnMajorPackedStrides`.
-    using DeducedStrides =
-        std::conditional_t<is_same_v<Strides, Tuple<>>,
-                           remove_cvref_t<decltype(GenerateColumnMajorPackedStrides(Shape{}))>,
-                           Strides>;
-    using FlattenDescriptorType =
-        remove_cvref_t<decltype(MakeFlattenDescriptor(Shape{}, DeducedStrides{}))>;
     using Descriptor1dType =
-        remove_cvref_t<decltype(MakeMerge1d(Shape{}, FlattenDescriptorType{}))>;
+        remove_cvref_t<decltype(MakeMerge1d(Shape{}, UnrolledDescriptorType{}))>;

     using DefaultIdxsTupleType = remove_cvref_t<decltype(GenerateDefaultIdxsTuple(Shape{}))>;

     public:
+    /**
+     * \brief Transform descriptor to align to passed indexes.
+     *
+     * \param shape Layout shape.
+     * \param idxs Indexes to align descriptor.
+     * \param naive_descriptor Descriptor to merge.
+     * \return Aligned descriptor to idx.
+     */
     template <typename... ShapeDims, typename... IdxDims>
     __host__ __device__ constexpr static auto TransformDesc(const Tuple<ShapeDims...>& shape,
-                                                            const Tuple<IdxDims...>& idx,
-                                                            const FlattenDescriptorType& naive_descriptor)
+                                                            const Tuple<IdxDims...>& idxs,
+                                                            const UnrolledDescriptorType& naive_descriptor)
     {
         if constexpr(Tuple<IdxDims...>::Size() == I1)
         {
...
@@ -249,55 +276,38 @@ struct Layout
             static_assert(Tuple<ShapeDims...>::Size() == Tuple<IdxDims...>::Size(),
                           "Idx rank and Shape rank must be the same (except 1d).");
             // Unroll while IdxDims is nested
-            const auto aligned_shape = AlignShapeToIdx(shape, idx);
+            const auto aligned_shape = AlignShapeToIdx(shape, idxs);
             // Transform correct form of shape
-            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idx), naive_descriptor);
+            return CreateMergedDescriptor(aligned_shape, UnrollNestedTuple(idxs), naive_descriptor);
         }
     }

     using MergedNestsDescriptorType = remove_cvref_t<decltype(TransformDesc(
-        Shape{}, DefaultIdxsTupleType{}, FlattenDescriptorType{}))>;
+        Shape{}, DefaultIdxsTupleType{}, UnrolledDescriptorType{}))>;

     public:
     __host__ __device__ constexpr auto GetElementSpaceSize() const
     {
-        return flatten_descriptor_.GetElementSpaceSize();
+        return unrolled_descriptor_.GetElementSpaceSize();
     }

     __host__ __device__ Layout() = delete;

     /**
      * \brief Layout constructor.
      *
      * \param shape Shape for layout.
-     * \param strides Strides for layout (optional if tensor is packed).
+     * \param unnested_descriptor Descriptor
      */
-    __host__ __device__ constexpr Layout(const Shape& shape, const Strides& strides)
-        : flatten_descriptor_{}, shape_(shape), strides_(strides)
-    {
-        // Construct if runtime mode
-        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
-        {
-            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-            descriptor_1d_      = MakeMerge1d(shape_, flatten_descriptor_);
-            merged_nests_descriptor_ =
-                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
-        }
-    }
-
-    /**
-     * \brief Layout constructor (with default packed column-major strides).
-     *
-     * \param shape Shape for layout.
-     */
-    __host__ __device__ constexpr Layout(const Shape& shape)
-        : flatten_descriptor_{}, shape_(shape), strides_(GenerateColumnMajorPackedStrides(shape_))
+    __host__ __device__ constexpr Layout(const Shape& shape,
+                                         const UnrolledDescriptorType& unnested_descriptor)
+        : unrolled_descriptor_(unnested_descriptor), shape_(shape)
     {
-        if constexpr(!FlattenDescriptorType::IsKnownAtCompileTime())
+        // Construct if runtime mode
+        if constexpr(!remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime())
         {
-            flatten_descriptor_ = MakeFlattenDescriptor(shape_, strides_);
-            descriptor_1d_      = MakeMerge1d(shape_, flatten_descriptor_);
+            descriptor_1d_ = MakeMerge1d(shape_, unrolled_descriptor_);
             merged_nests_descriptor_ =
-                TransformDesc(shape_, DefaultIdxsTupleType{}, flatten_descriptor_);
+                TransformDesc(shape_, DefaultIdxsTupleType{}, unrolled_descriptor_);
         }
     }
...
@@ -310,9 +320,9 @@ struct Layout
     template <typename Idxs>
     __host__ __device__ constexpr index_t operator()() const
     {
-        static_assert(FlattenDescriptorType::IsKnownAtCompileTime(),
+        static_assert(remove_cvref_t<UnrolledDescriptorType>::IsKnownAtCompileTime(),
                       "Compiletime operator used on runtime layout.");
-        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, FlattenDescriptorType{}));
+        using TransformedDesc = decltype(TransformDesc(Shape{}, Idxs{}, UnrolledDescriptorType{}));
         using UnrolledIdx = decltype(UnrollNestedTuple(Idxs{}));
         return TransformedDesc{}.CalculateOffset(UnrolledIdx{});
     }
...
@@ -339,7 +349,7 @@ struct Layout
         else
         {
             // Custom index, need to transform descriptor
-            const auto transformed_desc = TransformDesc(shape_, Idx, flatten_descriptor_);
+            const auto transformed_desc = TransformDesc(shape_, Idx, unrolled_descriptor_);
             return transformed_desc.CalculateOffset(UnrollNestedTuple(Idx));
         }
     }
...
@@ -351,7 +361,7 @@ struct Layout
      * \return Calculated size.
      */
     template <index_t IDim>
-    __host__ __device__ constexpr index_t GetLength() const
+    __host__ __device__ constexpr auto GetLength() const
     {
         const auto elem = shape_.At(Number<IDim>{});
         if constexpr(is_detected<is_tuple, tuple_element_t<IDim, Shape>>::value)
...
@@ -371,7 +381,7 @@ struct Layout
      *
      * \return Calculated size.
      */
-    __host__ __device__ constexpr index_t GetLengths() const
+    __host__ __device__ constexpr auto GetLengths() const
     {
         const auto unrolled_shape = UnrollNestedTuple(shape_);
         return TupleReduce<I0.value, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...
@@ -385,13 +395,6 @@ struct Layout
      */
     __host__ __device__ constexpr const Shape& GetShape() const { return shape_; }

-    /**
-     * \brief Strides getter.
-     *
-     * \return Strides.
-     */
-    __host__ __device__ constexpr const DeducedStrides& GetStrides() const { return strides_; }
-
     /**
      * \brief Get default lengths (tuple filled with Shape length elements).
      *
...
@@ -413,21 +416,56 @@ struct Layout
     }

     /**
-     * \brief Get default descriptor (with the same size as Shape)
+     * \brief Get descriptor with all nested dimensions merged.
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (4, 2)
+     *
+     * \note The size of merged descriptor is the same as Layout's shape.
      *
-     * \return Default descriptor.
+     * \return Merged nests descriptor.
      */
-    __host__ __device__ constexpr MergedNestsDescriptorType GetDefaultDescriptor()
+    __host__ __device__ constexpr const MergedNestsDescriptorType& GetMergedNestingDescriptor() const
     {
         return merged_nests_descriptor_;
     }

+    /**
+     * \brief Get descriptor with all dimensions are merged (1D).
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (8)
+     *
+     * \return 1D descriptor.
+     */
+    __host__ __device__ constexpr const Descriptor1dType& Get1DDescriptor() const
+    {
+        return descriptor_1d_;
+    }
+
+    /**
+     * \brief Get unnested descriptor (with unrolled dims)
+     * Example, shape: ((2, 2), 2)
+     * Descriptor lengths: (2, 2, 2)
+     *
+     * \return Flattened descriptor.
+     */
+    __host__ __device__ constexpr const UnrolledDescriptorType& GetUnrolledDescriptor() const
+    {
+        return unrolled_descriptor_;
+    }
+
     private:
-    FlattenDescriptorType flatten_descriptor_;
+    // All dimensions are unrolled
+    UnrolledDescriptorType unrolled_descriptor_;
+    // 1D descriptor
     Descriptor1dType descriptor_1d_;
+    // All nesting are merged
     MergedNestsDescriptorType merged_nests_descriptor_;
+    // Example, shape: ((2, 2), 2)
+    // UnrolledDescriptorType lengths: (2, 2, 2)
+    // Descriptor1dType lengths: (8)
+    // MergedNestsDescriptorType lengths: (4, 2)
     const Shape shape_;
-    const DeducedStrides strides_;
 };

 } // namespace wrapper
...
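After this change a Layout owns one unrolled descriptor and derives the 1D and merged-nests views from it. A hedged usage sketch for a compile-time shape ((2, 2), 2) with packed strides; the include path and exact spellings below are assumptions based on the headers shown in this commit:

#include "ck/wrapper/layout.hpp" // assumed include path

__host__ __device__ void layout_views_example()
{
    using ck::Number;
    using ck::make_tuple;

    // Nested shape ((2, 2), 2); no strides passed, so packed column-major strides are generated.
    const auto shape  = make_tuple(make_tuple(Number<2>{}, Number<2>{}), Number<2>{});
    const auto layout = ck::wrapper::make_layout(shape);

    const auto& unrolled = layout.GetUnrolledDescriptor();      // lengths (2, 2, 2)
    const auto& merged   = layout.GetMergedNestingDescriptor(); // lengths (4, 2)
    const auto& one_d    = layout.Get1DDescriptor();            // lengths (8)
    (void)unrolled; (void)merged; (void)one_d;
}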
include/ck/wrapper/operations/copy.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "../utils/tensor_utils.hpp"

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace wrapper {

/**
 * \brief Perform generic copy between two tensors partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename SrcTensorType, typename DstTensorType>
__host__ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    if constexpr(!SrcTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(src_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else if constexpr(!DstTensorType::IsDynamicBuffer)
    {
        using SizeType = decltype(size(dst_tensor));
        static_for<0, SizeType{}, 1>{}([&](auto i) { dst_tensor(i) = src_tensor(i); });
    }
    else
    {
        for(int i = 0; i < size(src_tensor); i++)
        {
            dst_tensor(i) = src_tensor(i);
        }
    }
}

/**
 * \brief Perform optimized copy between two tensors partitions (threadwise copy).
 * Tensors must have the same size.
 *
 * \tparam DimAccessOrderTuple Tuple with dimension access order.
 * \tparam VectorDim Dimension for vectorized read and write.
 * \tparam ScalarPerVector Number of scalar per vectorized read and write.
 * \param src_tensor Source tensor.
 * \param dst_tensor Destination tensor.
 */
template <typename DimAccessOrderTuple,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename SrcTensorType,
          typename DstTensorType>
__device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
{
    static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    const auto& in_grid_desc  = layout(src_tensor).GetUnrolledDescriptor();
    const auto& out_grid_desc = layout(dst_tensor).GetUnrolledDescriptor();

    using SrcShapeType = remove_cvref_t<decltype(shape(src_tensor))>;
    constexpr index_t num_dims = SrcShapeType::Size();

    constexpr auto thread_slice_lengths =
        generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
    constexpr auto dim_access_order =
        generate_sequence_v2([](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

    if constexpr(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform a copy between DynamicBuffers
        auto transfer = ThreadwiseTensorSliceTransfer_v7<
            Tuple<typename SrcTensorType::TensorElementType>,
            Tuple<typename DstTensorType::TensorElementType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            Sequence<false>,
            Sequence<false>>{in_grid_desc,
                             make_tuple(src_tensor.GetMultiIdxOffsets()),
                             out_grid_desc,
                             make_tuple(dst_tensor.GetMultiIdxOffsets()),
                             tensor_operation::element_wise::PassThrough{}};

        transfer.Run(tie(in_grid_desc),
                     tie(src_tensor.GetBuffer()),
                     tie(out_grid_desc),
                     tie(dst_tensor.GetBuffer()));
    }
    else if constexpr(!SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from StaticBuffer to DynamicBuffer
        const auto src_slice_origin_idxs =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v1r3<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            tensor_operation::element_wise::PassThrough,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            VectorDim,
            ScalarPerVector,
            InMemoryDataOperationEnum::Set,
            I1,
            true>{out_grid_desc,
                  dst_tensor.GetMultiIdxOffsets(),
                  tensor_operation::element_wise::PassThrough{}};

        transfer.Run(in_grid_desc,
                     src_slice_origin_idxs,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     dst_tensor.GetBuffer());
    }
    else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
    {
        // Perform copy from DynamicBuffer to StaticBuffer
        const auto src_dst_slice_origin =
            generate_tuple([&](auto) { return I0; }, Number<num_dims>{});

        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
            [&](auto I) {
                if constexpr(I == VectorDim)
                {
                    return Number<ScalarPerVector>{};
                }
                else
                {
                    return I1;
                }
            },
            Number<num_dims>{});

        auto transfer = ThreadwiseTensorSliceTransfer_v4r1<
            typename SrcTensorType::TensorElementType,
            typename DstTensorType::TensorElementType,
            remove_cvref_t<decltype(in_grid_desc)>,
            remove_cvref_t<decltype(out_grid_desc)>,
            decltype(thread_slice_lengths),
            decltype(dim_access_order),
            decltype(src_vector_tensor_lengths),
            decltype(dim_access_order)>{src_tensor.GetMultiIdxOffsets()};

        transfer.Run(in_grid_desc,
                     src_dst_slice_origin,
                     src_tensor.GetBuffer(),
                     out_grid_desc,
                     src_dst_slice_origin,
                     dst_tensor.GetBuffer());
    }
    else
    {
        // Perform copy between StaticBuffers
        copy(src_tensor, dst_tensor);
    }
}

} // namespace wrapper
} // namespace ck
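A hedged usage sketch of the two overloads above inside a kernel: the generic element-wise form, and the vectorized form that takes a dimension access order, a vector dimension and scalars-per-vector. Tensor types, tile sizes and the dim-order tuple are illustrative only:

template <typename GlobalTileType, typename LdsTileType, typename RegTileType>
__device__ void stage_tile(const GlobalTileType& global_tile, LdsTileType& lds_tile, RegTileType& reg_tile)
{
    // Generic threadwise copy: element by element, any matching sizes.
    ck::wrapper::copy(global_tile, lds_tile);

    // Vectorized copy: traverse dims in (0, 1) order, vectorize dim 1, 4 scalars per access.
    using DimOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    ck::wrapper::copy<DimOrder, /*VectorDim=*/1, /*ScalarPerVector=*/4>(lds_tile, reg_tile);
}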
include/ck/wrapper/tensor.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "utils/tensor_utils.hpp"
+#include "utils/tensor_partition.hpp"
 #include "utils/layout_utils.hpp"

 namespace ck {
 namespace wrapper {
+namespace detail {
+namespace {

 /**
- * \brief Tensor wrapper that performs static and dynamic buffer logic.
+ * \brief Check if Tuple contains Slice object
  *
- * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
- * \tparam ElementType Element data type.
- * \tparam Shape Tensor shape (layout component).
- * \tparam Strides Tensor strides (layout component).
- * \tparam NumVectors Number of vectors (only for VGPR, SGPR).
- * \tparam ScalarPerVector Scalars per vector (only for VGPR, SGPR).
+ * \return True if tuple contains Slice object.
  */
-template <MemoryTypeEnum BufferAddressSpace,
-          typename ElementType,
-          typename Shape,
-          typename Strides,
-          index_t NumVectors,     // param for Register memory
-          index_t ScalarPerVector // param for Register memory
-          >
-struct Tensor
+template <typename T>
+__host__ __device__ constexpr bool HasSlice(T&&)
 {
-    private:
-    // Check if Tuple contains Slice object
-    template <typename T>
-    constexpr static bool IsSlicing(T&&)
-    {
-        return is_detected<is_slice, T>::value;
-    }
-    template <typename... Ts>
-    constexpr static bool IsSlicing(Tuple<Ts...>&&)
-    {
-        return (IsSlicing(Ts{}) || ...);
-    }
-
-    // Calculate first index of new tensor after slice
-    // It is needed to calculate offset for new tensor
-    template <typename... Ts>
-    constexpr auto GetStartIdxForSlicedTensor(const Tuple<Ts...>& idx) const
-    {
-        const auto start_idx_for_sliced_tensor = generate_tuple(
-            [&](auto i) {
-                constexpr auto num_i = Number<i>{};
-                if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    // if tuple then recurrence
-                    return GetStartIdxForSlicedTensor(idx.At(num_i));
-                }
-                else if constexpr(is_detected<is_slice,
-                                              tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    // if slice, return the beginning of the interval
-                    return idx.At(num_i).from_;
-                }
-                else
-                {
-                    // if one dim selected
-                    return idx.At(num_i);
-                }
-            },
-            Number<Tuple<Ts...>::Size()>{});
-
-        return start_idx_for_sliced_tensor;
-    }
+    return is_detected<is_slice, T>::value;
+}
+template <typename... Ts>
+__host__ __device__ constexpr bool HasSlice(Tuple<Ts...>&&)
+{
+    return (HasSlice(Ts{}) || ...);
+}

-    // Calculate new tensor shape after slice
-    template <typename... Ts, typename ShapeTmpType>
-    constexpr auto GetShapeFromSlicedTensor(const Tuple<Ts...>& idx, const ShapeTmpType& shape) const
+/**
+ * \brief Calculate new shape after slice from parent shape.
+ *
+ * \param idxs Tuple of indexes defining slice ranges.
+ * \param shape Shape which will be sliced.
+ * \return New tensor shape.
+ */
+template <typename... Ts, typename SlicedShape>
+__host__ __device__ constexpr auto GetSlicedShape(const Tuple<Ts...>& idxs, const SlicedShape& shape)
 {
     // Pack each value in tuple to remove empty tuples after generation
     auto new_shape = generate_tuple(
         [&](auto i) {
             constexpr auto num_i = Number<i>{};
             if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
             {
-                if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
+                if constexpr(!detail::HasSlice(tuple_element_t<i.value, Tuple<Ts...>>{}))
                 {
                     // if tuple does not have any slice then we can remove dimension
                     return Tuple<>{};
...
@@ -90,15 +53,14 @@ struct Tensor
                 else
                 {
                     // if tuple then recurrence
-                    return make_tuple(GetShapeFromSlicedTensor(idx.At(num_i), shape.At(num_i)));
+                    return make_tuple(GetSlicedShape(idxs.At(num_i), shape.At(num_i)));
                 }
             }
             else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
             {
                 // calculate new dimension
                 const auto& dim = size(shape.At(num_i));
-                const auto val  = idx.At(num_i).range(dim);
+                const auto val  = idxs.At(num_i).range(dim);
                 return make_tuple(val);
             }
             else
...
@@ -110,50 +72,143 @@ struct Tensor
         Number<Tuple<Ts...>::Size()>{});
     // Remove empty tuples (deleted elements) and return
     return UnrollNestedTuple<0, 1>(new_shape);
 }

-    template <typename... Ts, typename StridesTmpType>
-    constexpr auto GetStridesFromSlicedTensor(const Tuple<Ts...>& idx,
-                                              const StridesTmpType& strides) const
-    {
-        // Pack each value in tuple to remove empty tuples after generation
-        auto new_strides = generate_tuple(
-            [&](auto i) {
-                constexpr auto num_i = Number<i>{};
-                if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    if constexpr(!IsSlicing(tuple_element_t<i.value, Tuple<Ts...>>{}))
-                    {
-                        // if tuple does not have any slice then we can remove dimension
-                        return Tuple<>{};
-                    }
-                    else
-                    {
-                        // if tuple then recurrence
-                        return make_tuple(GetStridesFromSlicedTensor(idx.At(num_i), strides.At(num_i)));
-                    }
-                }
-                else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
-                {
-                    // Stride will be the same
-                    return make_tuple(strides.At(num_i));
-                }
-                else
-                {
-                    // remove dimension for just value
-                    return Tuple<>{};
-                }
-            },
-            Number<Tuple<Ts...>::Size()>{});
-        // Remove empty tuples (deleted elements) and return
-        return UnrollNestedTuple<0, 1>(new_strides);
-    }
+/**
+ * \brief Generate Freeze for each of nested shape.
+ *
+ * \param idx Tuple of start indices for slice.
+ * \param shape Shape which will be freezed.
+ * \return Generated freeze transforms.
+ */
+template <typename T, typename Shape>
+__host__ __device__ constexpr auto GenerateMultipleFreeze(T idx, const Shape& shape)
+{
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    return generate_tuple(
+        [&](auto i) {
+            // dimension offset from idx
+            const auto dim     = unrolled_shape.At(Number<i>{});
+            const auto dim_idx = idx % dim;
+            idx /= dim;
+            return make_freeze_transform(dim_idx);
+        },
+        Number<decltype(unrolled_shape)::Size()>{});
+}
+
+/**
+ * \brief Generate transforms for slice tensor.
+ *
+ * \param idx Tuple of start indices for slice.
+ * \param shape Shape which will be sliced.
+ * \return Generated transforms.
+ */
+template <typename... Ts, typename Shape>
+__host__ __device__ constexpr auto GenerateSliceTransforms(const Tuple<Ts...>& idx, const Shape& shape)
+{
+    auto transforms = generate_tuple(
+        [&](auto i) {
+            constexpr auto num_i = Number<i>{};
+            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, Tuple<Ts...>>>::value)
+            {
+                return GenerateSliceTransforms(idx.At(num_i), shape.At(num_i));
+            }
+            else if constexpr(is_detected<is_slice, tuple_element_t<i.value, Tuple<Ts...>>>::value)
+            {
+                const auto from  = idx.At(num_i).from_;
+                const auto dim   = size<num_i>(shape);
+                const auto range = idx.At(num_i).range(dim);
+                return make_slice_transform(range, from, from + range);
+            }
+            else
+            {
+                // remove dimension for just value
+                return GenerateMultipleFreeze(idx.At(num_i), shape.At(num_i));
+            }
+        },
+        Number<Tuple<Ts...>::Size()>{});
+    return UnrollNestedTuple(transforms);
+}
+
+template <index_t i, typename LowerIndex>
+__host__ __device__ constexpr auto GetSequenceVal(const ck::Freeze<LowerIndex>&)
+{
+    // There is no output for Freeze transform
+    return Sequence<>{};
+}
+
+template <index_t i, typename LowLength, typename SliceBegin, typename SliceEnd>
+__host__ __device__ constexpr auto GetSequenceVal(const ck::Slice<LowLength, SliceBegin, SliceEnd>&)
+{
+    return Sequence<i>{};
+}
+
+template <index_t i>
+__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<>&)
+{
+    return Tuple<>{};
+}
+
+template <index_t i, typename... Transforms>
+__host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>& transforms)
+{
+    constexpr auto num_transforms = Tuple<Transforms...>::Size();
+    // Deduce Sequence element for specific transform
+    const auto current_elem = GetSequenceVal<i>(transforms.At(Number<0>{}));
+    if constexpr(is_same_v<decltype(current_elem), const Sequence<>>)
+    {
+        const auto next_tuple = GenerateUpperDims<i>(TupleSlice<1, num_transforms>(transforms));
+        return concat_tuple(make_tuple(current_elem), next_tuple);
+    }
+    else
+    {
+        // Increase i if current_elem is Slice transform
+        const auto next_tuple = GenerateUpperDims<i + 1>(TupleSlice<1, num_transforms>(transforms));
+        return concat_tuple(make_tuple(current_elem), next_tuple);
+    }
+}
+
+template <typename... Ts, typename Shape, typename FlattenDescriptor>
+__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
+                                                            const Shape& shape,
+                                                            const FlattenDescriptor& flatten_desc)
+{
+    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
+
+    const auto transforms     = GenerateSliceTransforms(idx, shape);
+    using TransformsTupleType = decltype(transforms);
+
+    const auto lower_dims =
+        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
+    const auto upper_dims = decltype(GenerateUpperDims<0>(TransformsTupleType{})){};
+    return transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
+}
+} // namespace
+} // namespace detail
+
+/**
+ * \brief Tensor wrapper that performs static and dynamic buffer logic.
+ * The tensor is based on a descriptor stored in the Layout. Additionally,
+ * tensor can be sliced or shifted using multi-index offset.
+ *
+ * \tparam BufferAddressSpace Memory type (Generic, Global, LDS, VGPR, SGPR).
+ * \tparam ElementType Element data type.
+ * \tparam Shape Tensor shape (layout component).
+ * \tparam UnrolledDescriptorType Flatten descriptor (layout component).
+ */
+template <MemoryTypeEnum BufferAddressSpace,
+          typename ElementType,
+          typename Shape,
+          typename UnrolledDescriptorType>
+struct Tensor
+{
     public:
-    using ElementSpaceSize =
-        decltype(Layout<Shape, Strides>{Shape{}, Strides{}}.GetElementSpaceSize()); // SpaceSize type for buffer
+    using ElementSpaceSize = decltype(Layout<Shape, UnrolledDescriptorType>{
+        Shape{}, UnrolledDescriptorType{}}.GetElementSpaceSize()); // SpaceSize type for buffer
     using TensorElementType = ElementType;                         // DataType

     static constexpr MemoryTypeEnum TensorBufferAddressSpace = BufferAddressSpace;
...
@@ -161,135 +216,178 @@ struct Tensor
                                                  BufferAddressSpace == MemoryTypeEnum::Vgpr);

     __host__ __device__ Tensor() = delete;
-    __host__ __device__ Tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
+    __host__ __device__ constexpr Tensor(ElementType* pointer,
+                                         const Layout<Shape, UnrolledDescriptorType>& layout)
         : layout_(layout),
-          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize()))
+          buffer_(make_dynamic_buffer<BufferAddressSpace>(pointer, layout.GetElementSpaceSize())),
+          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
+          base_offset_(0)
     {
         static_assert(IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
     }
-    __host__ __device__ Tensor(const Layout<Shape, Strides>& layout) : layout_(layout)
+    __host__ __device__ constexpr Tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
+        : layout_(layout),
+          multi_idx_offset_(make_zero_multi_index<Shape::Size()>()),
+          base_offset_(0)
     {
         static_assert(!IsDynamicBuffer, "Wrong BufferAddressSpace for register.");
     }

-    __host__ __device__ constexpr const Layout<Shape, Strides>& GetLayout() const
+    __host__ __device__ constexpr const Layout<Shape, UnrolledDescriptorType>& GetLayout() const
     {
         return layout_;
     }

-    // Getter for new sliced tensor
-    template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
-    __host__ __device__ auto operator[](const Tuple<Ts...>& idx) const
+    /**
+     * \brief Get the new sliced tensor.
+     *
+     * \param idx Tuple of indices: slice(from,to) or scalar.
+     * \return Sliced tensor.
+     */
+    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
+    __host__ __device__ auto operator[](const Tuple<Ts...>& idx)
     {
         static_assert(IsDynamicBuffer, "Register slice is not supported");
-        // Calculate offset based on first idx for new tensor
-        const index_t offset = layout_(GetStartIdxForSlicedTensor(idx));
-
-        auto new_shape = GetShapeFromSlicedTensor(idx, layout_.GetShape());
-        if constexpr(is_same_v<Strides, Tuple<>>)
-        {
-            auto new_layout = make_layout(new_shape);
-            return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
-        }
-        else
-        {
-            auto new_strides = GetStridesFromSlicedTensor(idx, layout_.GetStrides());
-            auto new_layout  = make_layout(new_shape, new_strides);
-            return make_tensor<BufferAddressSpace>(buffer_.p_data_ + offset, new_layout);
-        }
+        const auto& shape = layout_.GetShape();
+        auto new_shape    = detail::GetSlicedShape(idx, shape);
+
+        const auto& flatten_desc = layout_.GetUnrolledDescriptor();
+        auto new_desc = detail::GenerateSlicedDescriptor(idx, shape, flatten_desc);
+        const auto new_layout =
+            Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
+
+        // Update embed offset
+        base_offset_ -= new_layout(make_tuple(Number<0>{}));
+        return make_tensor<BufferAddressSpace>(buffer_.p_data_, new_layout);
     }

-    template <typename... Ts, enable_if_t<IsSlicing(Tuple<Ts...>{}), bool> = false>
-    __host__ __device__ auto operator()(const Tuple<Ts...>& idx) const
+    template <typename... Ts, enable_if_t<detail::HasSlice(Tuple<Ts...>{}), bool> = false>
+    __host__ __device__ auto operator()(const Tuple<Ts...>& idx)
     {
         return this->operator[](idx);
     }

-    template <typename... Idxs, enable_if_t<IsSlicing(Tuple<Idxs...>{}), bool> = false>
-    __host__ __device__ auto operator()(Idxs... idxs) const
+    template <typename... Idxs, enable_if_t<detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
+    __host__ __device__ auto operator()(Idxs... idxs)
     {
         return this->operator[](make_tuple(idxs...));
     }

-    // Getter for the const value
-    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
+    /**
+     * \brief Getter of the tensor's const value reference.
+     *
+     * \param idx Tuple of indices.
+     * \return Requested value.
+     */
+    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
     __host__ __device__ const ElementType& operator[](const Tuple<Ts...>& idx) const
     {
         if constexpr(IsDynamicBuffer)
         {
-            const index_t offset = layout_(idx);
+            const index_t offset = layout_(idx) + base_offset_;
             return buffer_[offset];
         }
         else
         {
-            if constexpr(is_same_v<Strides, Tuple<>>)
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
-                return buffer_[Number<offset>{}];
-            }
-            else
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
-                return buffer_[Number<offset>{}];
-            }
+            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
+                Shape{}, UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
+            // Calculate and apply base offset in compile-time
+            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
+                Shape{}, UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
+            return buffer_[Number<index_offset + base_offset>{}];
         }
     }

-    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
+    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
     __host__ __device__ const ElementType& operator()(const Tuple<Ts...>& idx) const
     {
         return this->operator[](idx);
     }

-    template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
+    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
     __host__ __device__ const ElementType& operator()(Idxs... idxs) const
     {
         return this->operator[](make_tuple(idxs...));
     }

-    // Getter for the value reference
-    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
+    /**
+     * \brief Getter of tensor value reference.
+     *
+     * \param idx Tuple of indices.
+     * \return Requested value.
+     */
+    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
     __host__ __device__ ElementType& operator[](const Tuple<Ts...>& idx)
     {
         if constexpr(IsDynamicBuffer)
         {
-            const index_t offset = layout_(idx);
+            const index_t offset = layout_(idx) + base_offset_;
             return buffer_(offset);
         }
         else
         {
-            if constexpr(is_same_v<Strides, Tuple<>>)
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}}.template operator()<Tuple<Ts...>>();
-                return buffer_(Number<offset>{});
-            }
-            else
-            {
-                constexpr index_t offset =
-                    Layout<Shape, Strides>{Shape{}, Strides{}}.template operator()<Tuple<Ts...>>();
-                return buffer_(Number<offset>{});
-            }
+            constexpr index_t index_offset = Layout<Shape, UnrolledDescriptorType>{
+                Shape{}, UnrolledDescriptorType{}}.template operator()<Tuple<Ts...>>();
+            // Apply embed offset (calculate in compiletime)
+            constexpr index_t base_offset = Layout<Shape, UnrolledDescriptorType>{
+                Shape{}, UnrolledDescriptorType{}}.template operator()<MultiIndex<Shape::Size()>>();
+            return buffer_(Number<index_offset + base_offset>{});
         }
     }

-    template <typename... Ts, enable_if_t<!IsSlicing(Tuple<Ts...>{}), bool> = false>
+    template <typename... Ts, enable_if_t<!detail::HasSlice(Tuple<Ts...>{}), bool> = false>
     __host__ __device__ ElementType& operator()(const Tuple<Ts...>& idx)
     {
         return this->operator[](idx);
     }

-    template <typename... Idxs, enable_if_t<!IsSlicing(Tuple<Idxs...>{}), bool> = false>
+    template <typename... Idxs, enable_if_t<!detail::HasSlice(Tuple<Idxs...>{}), bool> = false>
     __host__ __device__ ElementType& operator()(Idxs... idxs)
     {
         return this->operator[](make_tuple(idxs...));
     }

-    __host__ __device__ constexpr auto GetDefaultDescriptor()
-    {
-        return layout_.GetDefaultDescriptor();
-    }
+    /**
+     * \brief Get descriptor with all nested dimensions merged.
+     *
+     * \return Merged nests descriptor.
+     */
+    __host__ __device__ constexpr auto GetMergedNestingDescriptor()
+    {
+        return layout_.GetMergedNestingDescriptor();
+    }
+
+    /**
+     * \brief Get pointer to the data.
+     *
+     * \return Pointer.
+     */
+    __host__ __device__ ElementType* GetPointer() const { return buffer_.p_data_; }
+
+    __host__ __device__ constexpr auto& GetBuffer() { return buffer_; }
+    __host__ __device__ constexpr auto& GetBuffer() const { return buffer_; }
+
+    /**
+     * \brief Get multi index offset to the data.
+     *
+     * \return Multi index offset.
+     */
+    __host__ __device__ constexpr auto& GetMultiIdxOffsets() const { return multi_idx_offset_; }
+
+    /**
+     * \brief Apply multi index offset on the tensor.
+     *
+     * \param multi_idx_offset Multi index offset.
+     */
+    template <typename MultiIdxOffsets>
+    __host__ __device__ constexpr void SetMultiIdxOffset(const MultiIdxOffsets multi_idx_offset)
+    {
+        multi_idx_offset_ = multi_idx_offset;
+        base_offset_ += layout_(multi_idx_offset);
+    }

     private:
...
@@ -297,17 +395,28 @@ struct Tensor
                                           ElementType,
                                           ElementSpaceSize,
                                           true /*InvalidElementUseNumericalZeroValue*/>;
-    using StaticBufferType = StaticBufferTupleOfVector<BufferAddressSpace,
-                                                       ElementType,
-                                                       NumVectors,
-                                                       ScalarPerVector,
-                                                       true /*InvalidElementUseNumericalZeroValue*/>;
+    using StaticBufferType = StaticBuffer<BufferAddressSpace,
+                                          ElementType,
+                                          size(Shape{}),
+                                          true /*InvalidElementUseNumericalZeroValue*/>;
     // If register use static buffer, else use dynamic buffer
     using Buffer = std::conditional_t<IsDynamicBuffer, DynamicBufferType, StaticBufferType>;

-    const Layout<Shape, Strides> layout_;
+    const Layout<Shape, UnrolledDescriptorType> layout_;
     Buffer buffer_;
+    // We use multi_idx_offset_ to enable the creation of a descriptor in
+    // compile time for partitions or tiles if tile shape and thread layout
+    // is known at compile time (We can use the same descriptor for each
+    // thread). Additionally, the copy between the static and dynamic buffer
+    // requires a descriptor known at compile time, so we can shift data using
+    // such multi_idx_offset_.
+    MultiIndex<Shape::Size()> multi_idx_offset_;
+    // Base offset and multi index offset are corresponding to exactly the
+    // same element in tensor ( and in physical memory ). Multi index offset
+    // is multi dimensional index. However base offset is calculated using
+    // tensor descriptor (thus all it's transforms) and is linear (1D).
+    // We store base_offset_ to avoid multiple recalculations.
+    index_t base_offset_;
 };

 } // namespace wrapper
...
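A hedged sketch of the slicing interface the reworked operator[] supports: mixing slice(from, to) ranges with scalar indices yields a smaller tensor whose starting point is tracked through base_offset_. The slice helper and tensor types below are assumed to come from the wrapper utilities touched by this commit; shapes are illustrative:

template <typename TensorType>
__host__ __device__ auto take_views(TensorType& tensor_8x8)
{
    using ck::wrapper::slice;

    // Rows 2..3 over all columns: still a 2D tensor (2 x 8).
    auto row_band = tensor_8x8(slice(2, 4), slice(0, 8));

    // All rows of column 3: the scalar index removes that dimension,
    // leaving a 1D tensor of length 8.
    auto column = tensor_8x8(slice(0, 8), 3);

    return ck::make_tuple(row_band, column);
}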
include/ck/wrapper/utils/layout_utils.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -22,11 +22,69 @@ namespace wrapper {
// Disable from doxygen docs generation
/// @cond
// forward declaration
template
<
typename
Shape
,
typename
Strides
>
template
<
typename
Shape
,
typename
UnrolledDescriptorType
>
struct
Layout
;
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
namespace
{
/**
* \brief Generate packed (column-major) strides if not passed
*
* \param shape Tensor shape.
* \return Generated column-major strides.
*/
template
<
typename
...
Ts
>
__host__
__device__
constexpr
static
auto
GenerateColumnMajorPackedStrides
(
const
Tuple
<
Ts
...
>&
shape
)
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
return
generate_tuple
(
[
&
](
auto
i
)
{
if
constexpr
(
i
.
value
==
0
)
{
return
Number
<
1
>
{};
}
else
{
return
TupleReduce
<
Number
<
0
>
{}.
value
,
i
.
value
>
([](
auto
x
,
auto
y
)
{
return
x
*
y
;
},
unrolled_shape
);
}
},
Number
<
decltype
(
unrolled_shape
)
::
Size
()
>
{});
}
/**
* \brief Create naive tensor descriptor from nested shape.
*
* \param shape Tensor shape.
* \param strides Tensor strides.
* \return Unrolled descriptor
*/
template
<
typename
LayoutShape
,
typename
LayoutStrides
>
__host__
__device__
constexpr
auto
MakeUnrolledDescriptor
(
const
LayoutShape
&
shape
,
const
LayoutStrides
&
strides
)
{
const
auto
unrolled_shape
=
UnrollNestedTuple
(
shape
);
if
constexpr
(
is_same_v
<
LayoutStrides
,
Tuple
<>>
)
{
// if not passed, then generate
const
auto
unrolled_strides
=
GenerateColumnMajorPackedStrides
(
unrolled_shape
);
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
"Size of strides and shape are not consistent."
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
}
else
{
const
auto
unrolled_strides
=
UnrollNestedTuple
(
strides
);
static_assert
(
unrolled_shape
.
Size
()
==
unrolled_strides
.
Size
(),
"Size of strides and shape are not consistent."
);
return
make_naive_tensor_descriptor
(
unrolled_shape
,
unrolled_strides
);
}
}
}
// namespace
/// @endcond
// make_*
...
...
@@ -38,10 +96,10 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \return Constructed layout.
*/
template
<
typename
Shape
,
typename
Strides
>
__host__
__device__
constexpr
Layout
<
Shape
,
Strides
>
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
__host__
__device__
constexpr
auto
make_layout
(
const
Shape
&
shape
,
const
Strides
&
strides
)
{
return
Layout
<
Shape
,
Strides
>
(
shape
,
strides
);
using
UnrolledDescriptorType
=
decltype
(
MakeUnrolledDescriptor
(
Shape
{},
Strides
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
MakeUnrolledDescriptor
(
shape
,
strides
));
}
/**
...
...
@@ -52,16 +110,21 @@ __host__ __device__ constexpr Layout<Shape, Strides> make_layout(const Shape& sh
* \return Constructed layout.
*/
template
<
typename
Shape
>
__host__
__device__
constexpr
Layout
<
Shape
,
Tuple
<>>
make_layout
(
const
Shape
&
shape
)
__host__
__device__
constexpr
auto
make_layout
(
const
Shape
&
shape
)
{
return
Layout
<
Shape
,
Tuple
<>>
(
shape
);
using
UnrolledDescriptorType
=
decltype
(
MakeUnrolledDescriptor
(
Shape
{},
Tuple
<>
{}));
return
Layout
<
Shape
,
UnrolledDescriptorType
>
(
shape
,
MakeUnrolledDescriptor
(
shape
,
Tuple
<>
{}));
}
// Layout helpers
// get
// Get dim (could be returned from get with empty Idxs)
/**
* \private
* \brief Get dim.
*
* \param dim Dimension.
* \return Returned the same dimension.
*/
template
<
typename
T
>
__host__
__device__
T
constexpr
get
(
const
T
&
dim
)
...
...
@@ -89,26 +152,51 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr auto get(const Layout<Shape, Strides>& layout)
+template <index_t idx, typename Shape, typename FlattenDesc>
+__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
 {
     const auto& shape = layout.GetShape();
-    const auto& new_shape = get<idx>(shape);
+    const auto new_shape  = get<idx>(shape);
     static_assert(is_detected<is_tuple, decltype(new_shape)>::value,
                   "Shape of sub layout must be tuple");
-    if constexpr(is_same_v<Strides, Tuple<>>)
-    {
-        // If stride not passed, create without strides
-        return make_layout(new_shape);
-    }
-    else
-    {
-        const auto& strides     = layout.GetStrides();
-        const auto& new_strides = get<idx>(strides);
-        static_assert(is_detected<is_tuple, decltype(new_strides)>::value,
-                      "Strides of sub layout must be tuple");
-        return make_layout(new_shape, new_strides);
-    }
+
+    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
+    constexpr auto new_shape_dims = decltype(UnrollNestedTuple(new_shape))::Size();
+    constexpr auto shape_offset   = decltype(UnrollNestedTuple(TupleSlice<0, idx>(shape)))::Size();
+
+    const auto unrolled_shape = UnrollNestedTuple(shape);
+    const auto transforms     = generate_tuple(
+        [&](auto i) {
+            // Compare Idx with shape
+            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
+            {
+                // Remove dimension
+                return make_freeze_transform(Number<0>{});
+            }
+            else
+            {
+                return make_pass_through_transform(unrolled_shape.At(i));
+            }
+        },
+        Number<old_shape_dims>{});
+    const auto lower_dims =
+        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<old_shape_dims>{});
+    const auto upper_dims = generate_tuple(
+        [&](auto i) {
+            if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
+                return Sequence<>{};
+            else
+            {
+                return Sequence<i.value - shape_offset>{};
+            }
+        },
+        Number<old_shape_dims>{});
+
+    const auto& flatten_desc = layout.GetUnrolledDescriptor();
+    auto new_desc = transform_tensor_descriptor(flatten_desc, transforms, lower_dims, upper_dims);
+    return Layout<decltype(new_shape), decltype(new_desc)>(new_shape, new_desc);
 }
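A usage sketch for the reworked sub-layout get; the nested shape below is hypothetical:

// Nested layout ((2, 3), 4): dimension 0 is itself a 2x3 tuple.
const auto nested = ck::wrapper::make_layout(
    make_tuple(make_tuple(Number<2>{}, Number<3>{}), Number<4>{}));

// Sub-layout over dimension 0: its shape is (2, 3); the remaining dimension is
// frozen out of the unrolled descriptor via make_freeze_transform.
const auto sub = ck::wrapper::get<0>(nested);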
/**
...
...
@@ -125,9 +213,12 @@ __host__ __device__ constexpr auto get(const T& elem)
}
// size
// Get dim size (could be returned from get function)
/**
* \private
* \brief Get size.
*
* \param dim Size.
* \return Returned the same size.
*/
template <typename T>
__host__ __device__ T constexpr size(const T& dim)
...
...
@@ -142,8 +233,8 @@ __host__ __device__ T constexpr size(const T& dim)
* \param layout Layout to get Shape of.
* \return Requested length.
*/
-template <index_t idx, typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+template <index_t idx, typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.template GetLength<idx>();
 }
...
...
@@ -155,7 +246,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requested size.
*/
 template <typename... ShapeDims>
-__host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
+__host__ __device__ constexpr auto size(const Tuple<ShapeDims...>& shape)
 {
     const auto unrolled_shape = UnrollNestedTuple(shape);
     return TupleReduce<0, unrolled_shape.Size()>([](auto x, auto y) { return x * y; },
...
...
@@ -168,8 +259,8 @@ __host__ __device__ constexpr index_t size(const Tuple<ShapeDims...>& shape)
* \param layout Layout to calculate shape size.
* \return Requested size.
*/
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto size(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return layout.GetLengths();
 }
...
...
@@ -182,7 +273,7 @@ __host__ __device__ constexpr index_t size(const Layout<Shape, Strides>& layout)
* \return Requested length.
*/
 template <index_t idx, typename... Ts>
-__host__ __device__ constexpr index_t size(const Tuple<Ts...>& tuple)
+__host__ __device__ constexpr auto size(const Tuple<Ts...>& tuple)
 {
     return size(tuple.At(Number<idx>{}));
 }
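A sketch of the size helpers on a hypothetical nested layout; the values in the comments assume a packed layout:

const auto l = ck::wrapper::make_layout(
    make_tuple(make_tuple(Number<2>{}, Number<3>{}), Number<4>{}));

// size(tuple) multiplies all unrolled extents together:
const auto total = ck::wrapper::size(ck::wrapper::shape(l)); // 2 * 3 * 4 = 24
// size<idx>(layout) forwards to the layout's GetLength<idx>():
const auto len1 = ck::wrapper::size<1>(l);                   // 4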
...
...
@@ -208,8 +299,9 @@ __host__ __device__ constexpr auto size(const T& elem)
* \param layout Layout to calculate rank.
* \return Requested rank.
*/
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr auto rank([[maybe_unused]] const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+rank([[maybe_unused]] const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     return Shape::Size();
 }
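For the same hypothetical nested layout, rank counts only the top-level dimensions, while depth (defined further below) reports how deeply the shape tuple nests:

const auto l = ck::wrapper::make_layout(
    make_tuple(make_tuple(Number<2>{}, Number<3>{}), Number<4>{}));

const auto r = ck::wrapper::rank(l);  // 2: ((2, 3), 4) has two top-level dimensions
const auto d = ck::wrapper::depth(l); // 2 here, assuming a flat tuple counts as depth 1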
...
...
@@ -229,17 +321,25 @@ __host__ __device__ constexpr auto rank([[maybe_unused]] const Tuple<Dims...>& t
/**
* \private
* \brief Rank for scalar
*
* \param dim Dimension scalar.
* \return Returned 1.
*/
 template <index_t IDim>
-__host__ __device__ constexpr index_t rank(const Number<IDim>&)
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const Number<IDim>& dim)
 {
     return 1;
 }

 /**
  * \private
  * \brief Rank for scalar
  *
  * \param dim Dimension scalar.
  * \return Returned 1.
  */
-__host__ __device__ constexpr index_t rank(const index_t&)
-{
-    return 1;
-}
+__host__ __device__ constexpr index_t rank([[maybe_unused]] const index_t& dim)
+{
+    return 1;
+}
/**
* \brief Hierarchical rank.
...
...
@@ -261,8 +361,8 @@ __host__ __device__ constexpr auto rank(const T& elem)
* \param layout Layout to calculate depth.
* \return Requested depth.
*/
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr auto depth(const Layout<Shape, Strides>& layout)
+template <typename Shape, typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto depth(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     const auto& shape = layout.GetShape();
     return TupleDepth(shape);
...
...
@@ -282,17 +382,25 @@ __host__ __device__ constexpr auto depth(const Tuple<Dims...>& tuple)
/**
* \private
* \brief Depth for scalar
*
* \param dim Scalar.
* \return Returned 0.
*/
 template <index_t IDim>
-__host__ __device__ constexpr index_t depth(const Number<IDim>&)
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const Number<IDim>& dim)
 {
     return 0;
 }

 /**
  * \private
  * \brief Depth for scalar
  *
  * \param dim Scalar.
  * \return Returned 0.
  */
-__host__ __device__ constexpr index_t depth(const index_t&)
-{
-    return 0;
-}
+__host__ __device__ constexpr index_t depth([[maybe_unused]] const index_t& dim)
+{
+    return 0;
+}
/**
* \brief Hierarchical depth.
...
...
@@ -307,26 +415,14 @@ __host__ __device__ constexpr auto depth(const T& elem)
     return depth(get<Idxs...>(elem));
 }

-/**
- * \brief Get Layout strides.
- *
- * \param layout Layout to get strides from.
- * \return Requested strides.
- */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr const auto& stride(const Layout<Shape, Strides>& layout)
-{
-    return layout.GetStrides();
-}
-
 /**
  * \brief Get Layout shape.
  *
  * \param layout Layout to get shape from.
  * \return Requested shape.
  */
-template <typename Shape, typename Strides>
-__host__ __device__ constexpr const auto& shape(const Layout<Shape, Strides>& layout)
+template <typename LayoutType>
+__host__ __device__ constexpr const auto& shape(const LayoutType& layout)
 {
     return layout.GetShape();
 }
...
...
include/ck/wrapper/utils/tensor_partition.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
namespace ck {
namespace wrapper {
namespace {
/**
* \brief Calculate the partition shape based on the number of threads per dimension and
* the previous shape.
*
* \param shape Base tensor shape.
* \param thread_lengths Tuple of thread lengths.
* \return Partition shape.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts...>& shape,
                                                                const Tuple<Ls...>& thread_lengths)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            const auto slice_len = size<num_i>(shape) / thread_lengths.At(num_i);
            return slice_len;
        },
        Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Calculate total number of blocks.
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
                                                     const Tuple<Ls...>& tile_shape)
{
    static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), "Wrong thread_lengths shape.");
    return generate_tuple([&](auto i) { return size<i>(shape) / size<i>(tile_shape); },
                          Number<Tuple<Ls...>::Size()>{});
}
/**
* \brief Calculate scaled offset for new partition/tile.
*
* \param thread_idxs Thread multi index.
* \param partition_lengths_seq Sequence of partition shape.
* \param old_offset_idxs Multi index offset from base tensor to shift values.
* \return Shifted multi index offsets.
*/
template <typename ThreadIdxs, typename PartitionLengthsSeq, typename OldOffsetIdxs>
__host__ __device__ constexpr auto
CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
                         const PartitionLengthsSeq& partition_lengths_seq,
                         const OldOffsetIdxs& old_offset_idxs)
{
    return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
} // namespace
/**
* \brief Create local partition for a thread (for now, only packed partitions
* are supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (must not be nested).
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
                     const index_t thread_id)
{
    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
    // Calculate new partition shape
    const auto& tensor_shape = shape(tensor);
    constexpr auto partition_shape =
        CalculateLocalPartitionShape(decltype(tensor_shape){}, ThreadLengthsTuple{});
    // Create Thread Cluster Descriptor
    constexpr auto partition_lengths_seq = generate_sequence_v2(
        [&](auto I) { return size<I>(partition_shape); }, Number<ThreadLengthsTuple::Size()>{});
    constexpr auto thread_lengths_seq = generate_sequence_v2(
        [&](auto I) { return size<I>(ThreadLengthsTuple{}); }, Number<ThreadLengthsTuple::Size()>{});
    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
    // Calculate thread idxs and offsets
    const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
    const auto offset_multi_idxs =
        CalculateOffsetMultiIdxs(thread_idxs, partition_lengths_seq, tensor.GetMultiIdxOffsets());
    // Create new layout and tensor
    auto& flatten_desc = layout(tensor).GetUnrolledDescriptor();
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(flatten_desc)>(
            partition_shape, flatten_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
    // Apply offsets
    partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
    return partition_tensor;
}
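A device-side usage sketch; the sizes are hypothetical and p_global stands for an ElementType* that is not part of the change:

// Split a 256x128 tensor across a 4x64 thread arrangement, so every thread owns
// a packed 64x2 partition.
const auto t_layout = ck::wrapper::make_layout(make_tuple(Number<256>{}, Number<128>{}));
auto t = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_global, t_layout);

const auto thread_lengths = make_tuple(Number<4>{}, Number<64>{});
auto per_thread =
    ck::wrapper::make_local_partition(t, thread_lengths, static_cast<ck::index_t>(threadIdx.x));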
/**
* \brief Create local tile for a thread block (for now, only packed tiles
* are supported).
*
* \note For the time being, use a 2D tile_shape to get the best performance.
*
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id)
{
    static_assert(!IsNestedTuple(BlockShapeTuple{}));

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

    if constexpr(BlockShapeTuple::Size() == I2)
    {
        // Optimized version for 2d tile shape [MxK]
        const auto block_2_tile_map =
            BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
                                            BlockShapeTuple{}.At(I1),
                                            remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
        const auto offset_multi_idxs =
            make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
                                                                                     aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
        return tile_tensor;
    }
    else
    {
        // Calculate offsets
        // Sequence with data to process per block
        constexpr auto tile_shape_seq = generate_sequence_v2(
            [](auto I) { return size(BlockShapeTuple{}.At(I)); }, Number<BlockShapeTuple::Size()>{});
        // Tuple with number of blocks
        const auto block_lengths = CalculateGridSize(shape(tensor), tile_shape);
        constexpr auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
        const auto block_idxs =
            block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
        const auto offset_multi_idxs =
            CalculateOffsetMultiIdxs(block_idxs, tile_shape_seq, tensor.GetMultiIdxOffsets());
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
                                                                                     aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
        return tile_tensor;
    }
}
} // namespace wrapper
} // namespace ck
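A matching sketch for the block-level tile (hypothetical sizes again; a 2D tile_shape takes the optimized BlockToCTileMap branch above):

// Carve a 128x64 tile per workgroup out of a 1024x512 tensor.
const auto g_layout = ck::wrapper::make_layout(make_tuple(Number<1024>{}, Number<512>{}));
auto g = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_global, g_layout);

const auto tile_shape = make_tuple(Number<128>{}, Number<64>{});
auto block_tile =
    ck::wrapper::make_local_tile(g, tile_shape, static_cast<ck::index_t>(blockIdx.x));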
include/ck/wrapper/utils/tensor_utils.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -10,6 +10,7 @@
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/dynamic_buffer.hpp"
#include "ck/utility/amd_address_space.hpp"
#include "ck/utility/multi_index.hpp"
namespace ck {
namespace wrapper {
...
...
@@ -27,16 +28,12 @@ using MemoryTypeEnum = AddressSpaceEnum;
// Disable from doxygen docs generation
/// @cond
// forward declarations
-template <typename Shape, typename Strides>
+template <typename Shape, typename UnrolledDescriptorType>
 struct Layout;

 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,     // params for Register memory
-          index_t ScalarPerVector // param for Register memory
-          >
+          typename UnrolledDescriptorType>
 struct Tensor;

 template <typename FromType, typename ToType>
...
...
@@ -45,13 +42,22 @@ struct Slice
     __host__ __device__ constexpr Slice() : from_(), to_() {}
     __host__ __device__ constexpr Slice(FromType from, ToType to) : from_(from), to_(to) {}

+    /**
+     * \brief Calculate slice range.
+     *
+     * \param dim Dimension size.
+     * \return Slice range.
+     */
     template <typename T>
     __host__ __device__ constexpr auto range(const T& dim) const
     {
         if constexpr(is_same_v<FromType, index_t> || is_same_v<ToType, index_t> ||
                      is_same_v<T, index_t>)
         {
-            assert(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_) && "Invalid range");
+            if(!(dim >= to_ && from_ >= 0 && (to_ < 0 || to_ > from_)))
+            {
+                throw std::runtime_error("Invalid range");
+            }
             if(to_ < 0)
             {
                 return dim - from_ + to_ + 1;
...
...
@@ -98,33 +104,30 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
* \param layout Tensor layout.
* \return Constructed tensor.
*/
-template <MemoryTypeEnum MemoryType, typename ElementType, typename Shape, typename Strides>
-constexpr auto make_tensor(ElementType* pointer, const Layout<Shape, Strides>& layout)
+template <MemoryTypeEnum MemoryType,
+          typename ElementType,
+          typename Shape,
+          typename UnrolledDescriptorType>
+constexpr auto make_tensor(ElementType* pointer,
+                           const Layout<Shape, UnrolledDescriptorType>& layout)
 {
-    return Tensor<MemoryType, ElementType, Shape, Strides, 0 /*NumVectors*/, 0 /*ScalarPerVector*/>(
-        pointer, layout);
+    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(pointer, layout);
 }
/**
* \brief Make SGPR or VGPR tensor function.
*
* \tparam MemoryType Type of memory.
* \tparam NumVectors Number of vectors.
* \tparam ScalarPerVector Scalars per vector.
* \tparam ElementType Memory data type.
* \param layout Tensor layout.
* \return Constructed tensor.
*/
 template <MemoryTypeEnum MemoryType,
           index_t NumVectors,
           index_t ScalarPerVector,
           typename ElementType,
           typename Shape,
-          typename Strides>
-constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
+          typename UnrolledDescriptorType>
+constexpr auto make_register_tensor(const Layout<Shape, UnrolledDescriptorType>& layout)
 {
     static_assert(!IsNestedTuple(Shape{}), "Register tensor with nested layout is not supported");
-    return Tensor<MemoryType, ElementType, Shape, Strides, NumVectors, ScalarPerVector>(layout);
+    return Tensor<MemoryType, ElementType, Shape, UnrolledDescriptorType>(layout);
 }
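A usage sketch of the two factories with the new descriptor template parameter; values are hypothetical and p_global stands for an assumed float*:

// Global-memory tensor over a 64x64 packed layout.
const auto g_layout = ck::wrapper::make_layout(make_tuple(Number<64>{}, Number<64>{}));
auto g_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(p_global, g_layout);

// Small VGPR accumulator: 4x2 layout, one vector of eight scalars.
const auto r_layout = ck::wrapper::make_layout(make_tuple(Number<4>{}, Number<2>{}));
auto r_tensor = ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr,
                                                  1 /*NumVectors*/,
                                                  8 /*ScalarPerVector*/,
                                                  float>(r_layout);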
/**
...
...
@@ -136,12 +139,9 @@ constexpr auto make_register_tensor(const Layout<Shape, Strides>& layout)
 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
+          typename UnrolledDescriptorType>
 __host__ __device__ constexpr const auto&
-layout(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-           tensor)
+layout(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return tensor.GetLayout();
 }
...
...
@@ -157,12 +157,9 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-size(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-         tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+size(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return size<Idxs...>(tensor.GetLayout());
 }
...
...
@@ -178,12 +175,9 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-rank(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-         tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+rank(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return rank<Idxs...>(tensor.GetLayout());
 }
...
...
@@ -199,35 +193,13 @@ template <index_t... Idxs,
           MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr index_t
-depth(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-          tensor)
+          typename UnrolledDescriptorType>
+__host__ __device__ constexpr auto
+depth(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return depth<Idxs...>(tensor.GetLayout());
 }

-/**
- * \brief Get Tensor strides.
- *
- * \param tensor Tensor to get strides from.
- * \return Requested strides.
- */
-template <MemoryTypeEnum BufferAddressSpace,
-          typename ElementType,
-          typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
-__host__ __device__ constexpr const auto&
-stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-           tensor)
-{
-    return stride(tensor.GetLayout());
-}
/**
* \brief Get Tensor shape.
*
...
...
@@ -237,12 +209,9 @@ stride(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors,
 template <MemoryTypeEnum BufferAddressSpace,
           typename ElementType,
           typename Shape,
-          typename Strides,
-          index_t NumVectors,
-          index_t ScalarPerVector>
+          typename UnrolledDescriptorType>
 __host__ __device__ constexpr const auto&
-shape(const Tensor<BufferAddressSpace, ElementType, Shape, Strides, NumVectors, ScalarPerVector>&
-          tensor)
+shape(const Tensor<BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType>& tensor)
 {
     return shape(tensor.GetLayout());
 }
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp
...
...
@@ -265,6 +265,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error(
+            "Col2Img: number of dimensions should be between 1 and 3.");
         return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp
...
...
@@ -313,6 +313,9 @@ struct ReferenceConvBwdData : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error(
+            "Conv_bwd_data: number of dimensions must be between 1 and 3.");
         return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp
...
...
@@ -265,6 +265,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
             return 0;
         }
+        throw std::runtime_error(
+            "Conv_bwd: number of dimensions must be between 1 and 3.");
         return 1;
     }

     float Run(const device::BaseArgument* p_arg,
...
...